diff --git a/cmd/checkheader/main.go b/cmd/checkheader/main.go index 13232dc..002e512 100644 --- a/cmd/checkheader/main.go +++ b/cmd/checkheader/main.go @@ -117,7 +117,7 @@ func generate(header string) (string, error) { } continue } - if strings.HasPrefix(l, "enum {") { + if strings.HasPrefix(l, "enum {") || strings.HasPrefix(l, "typedef enum {") { enum = true continue } diff --git a/openssl/ecdh.go b/openssl/ecdh.go new file mode 100644 index 0000000..a13c411 --- /dev/null +++ b/openssl/ecdh.go @@ -0,0 +1,212 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//go:build linux && !android +// +build linux,!android + +package openssl + +// #include "goopenssl.h" +import "C" +import ( + "errors" + "runtime" +) + +type PublicKeyECDH struct { + curve string + key C.GO_EC_POINT_PTR + bytes []byte +} + +func (k *PublicKeyECDH) finalize() { + C.go_openssl_EC_POINT_free(k.key) +} + +type PrivateKeyECDH struct { + curve string + key C.GO_EC_KEY_PTR +} + +func (k *PrivateKeyECDH) finalize() { + C.go_openssl_EC_KEY_free(k.key) +} + +func NewPublicKeyECDH(curve string, bytes []byte) (*PublicKeyECDH, error) { + if len(bytes) < 1 { + return nil, errors.New("NewPublicKeyECDH: missing key") + } + + nid, err := curveNID(curve) + if err != nil { + return nil, err + } + + group := C.go_openssl_EC_GROUP_new_by_curve_name(nid) + if group == nil { + return nil, newOpenSSLError("EC_GROUP_new_by_curve_name") + } + defer C.go_openssl_EC_GROUP_free(group) + key := C.go_openssl_EC_POINT_new(group) + if key == nil { + return nil, newOpenSSLError("EC_POINT_new") + } + ok := C.go_openssl_EC_POINT_oct2point(group, key, base(bytes), C.size_t(len(bytes)), nil) != 0 + if !ok { + C.go_openssl_EC_POINT_free(key) + return nil, errors.New("point not on curve") + } + + k := &PublicKeyECDH{curve, key, append([]byte(nil), bytes...)} + runtime.SetFinalizer(k, (*PublicKeyECDH).finalize) + return k, nil +} + +func (k *PublicKeyECDH) Bytes() []byte { return k.bytes } + +func NewPrivateKeyECDH(curve string, bytes []byte) (*PrivateKeyECDH, error) { + nid, err := curveNID(curve) + if err != nil { + return nil, err + } + key := C.go_openssl_EC_KEY_new_by_curve_name(nid) + if key == nil { + return nil, newOpenSSLError("EC_KEY_new_by_curve_name") + } + b := bytesToBN(bytes) + ok := b != nil && C.go_openssl_EC_KEY_set_private_key(key, b) != 0 + if b != nil { + C.go_openssl_BN_free(b) + } + if !ok { + C.go_openssl_EC_KEY_free(key) + return nil, newOpenSSLError("EC_KEY_set_private_key") + } + k := &PrivateKeyECDH{curve, key} + runtime.SetFinalizer(k, (*PrivateKeyECDH).finalize) + return k, nil +} + +func (k *PrivateKeyECDH) PublicKey() (*PublicKeyECDH, error) { + defer runtime.KeepAlive(k) + + group := C.go_openssl_EC_KEY_get0_group(k.key) + if group == nil { + return nil, newOpenSSLError("EC_KEY_get0_group") + } + kbig := C.go_openssl_EC_KEY_get0_private_key(k.key) + if kbig == nil { + return nil, newOpenSSLError("EC_KEY_get0_private_key") + } + pt := C.go_openssl_EC_POINT_new(group) + if pt == nil { + return nil, newOpenSSLError("EC_POINT_new") + } + if C.go_openssl_EC_POINT_mul(group, pt, kbig, nil, nil, nil) == 0 { + C.go_openssl_EC_POINT_free(pt) + return nil, newOpenSSLError("EC_POINT_mul") + } + bytes, err := pointBytesECDH(k.curve, group, pt) + if err != nil { + C.go_openssl_EC_POINT_free(pt) + return nil, err + } + pub := &PublicKeyECDH{k.curve, pt, bytes} + // Note: Same as in NewPublicKeyECDH regarding finalizer and KeepAlive. + runtime.SetFinalizer(pub, (*PublicKeyECDH).finalize) + return pub, nil +} + +func pointBytesECDH(curve string, group C.GO_EC_GROUP_PTR, pt C.GO_EC_POINT_PTR) ([]byte, error) { + out := make([]byte, 1+2*curveSize(curve)) + n := C.go_openssl_EC_POINT_point2oct(group, pt, C.GO_POINT_CONVERSION_UNCOMPRESSED, base(out), C.size_t(len(out)), nil) + if int(n) != len(out) { + return nil, newOpenSSLError("EC_POINT_point2oct") + } + return out, nil +} + +func ECDH(priv *PrivateKeyECDH, pub *PublicKeyECDH) ([]byte, error) { + group := C.go_openssl_EC_KEY_get0_group(priv.key) + if group == nil { + return nil, newOpenSSLError("EC_KEY_get0_group") + } + privBig := C.go_openssl_EC_KEY_get0_private_key(priv.key) + if privBig == nil { + return nil, newOpenSSLError("EC_KEY_get0_private_key") + } + pt := C.go_openssl_EC_POINT_new(group) + if pt == nil { + return nil, newOpenSSLError("EC_POINT_new") + } + defer C.go_openssl_EC_POINT_free(pt) + if C.go_openssl_EC_POINT_mul(group, pt, nil, pub.key, privBig, nil) == 0 { + return nil, newOpenSSLError("EC_POINT_mul") + } + out, err := xCoordBytesECDH(priv.curve, group, pt) + if err != nil { + return nil, err + } + return out, nil +} + +func xCoordBytesECDH(curve string, group C.GO_EC_GROUP_PTR, pt C.GO_EC_POINT_PTR) ([]byte, error) { + big := C.go_openssl_BN_new() + defer C.go_openssl_BN_free(big) + if C.go_openssl_EC_POINT_get_affine_coordinates_GFp(group, pt, big, nil, nil) == 0 { + return nil, newOpenSSLError("EC_POINT_get_affine_coordinates_GFp") + } + return bigBytesECDH(curve, big) +} + +func bigBytesECDH(curve string, big C.GO_BIGNUM_PTR) ([]byte, error) { + out := make([]byte, curveSize(curve)) + if C.go_openssl_BN_bn2binpad(big, base(out), C.int(len(out))) == 0 { + return nil, newOpenSSLError("BN_bn2binpad") + } + return out, nil +} + +func curveSize(curve string) int { + switch curve { + default: + panic("openssl: unknown curve " + curve) + case "P-256": + return 256 / 8 + case "P-384": + return 384 / 8 + case "P-521": + return (521 + 7) / 8 + } +} + +func GenerateKeyECDH(curve string) (*PrivateKeyECDH, []byte, error) { + pkey, err := generateEVPPKey(C.GO_EVP_PKEY_EC, 0, curve) + if err != nil { + return nil, nil, err + } + defer C.go_openssl_EVP_PKEY_free(pkey) + key := C.go_openssl_EVP_PKEY_get1_EC_KEY(pkey) + if key == nil { + return nil, nil, newOpenSSLError("EVP_PKEY_get1_EC_KEY") + } + group := C.go_openssl_EC_KEY_get0_group(key) + if group == nil { + C.go_openssl_EC_KEY_free(key) + return nil, nil, newOpenSSLError("EC_KEY_get0_group") + } + b := C.go_openssl_EC_KEY_get0_private_key(key) + if b == nil { + C.go_openssl_EC_KEY_free(key) + return nil, nil, newOpenSSLError("EC_KEY_get0_private_key") + } + bytes, err := bigBytesECDH(curve, b) + if err != nil { + C.go_openssl_EC_KEY_free(key) + return nil, nil, err + } + + k := &PrivateKeyECDH{curve, key} + runtime.SetFinalizer(k, (*PrivateKeyECDH).finalize) + return k, bytes, nil +} diff --git a/openssl/ecdh_test.go b/openssl/ecdh_test.go new file mode 100644 index 0000000..ed8178e --- /dev/null +++ b/openssl/ecdh_test.go @@ -0,0 +1,152 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +//go:build linux && !android +// +build linux,!android + +package openssl_test + +import ( + "bytes" + "encoding/hex" + "testing" + + "github.com/microsoft/go-crypto-openssl/openssl" +) + +func TestECDH(t *testing.T) { + for _, tt := range []string{"P-256", "P-384", "P-521"} { + t.Run(tt, func(t *testing.T) { + name := tt + aliceKey, alicPrivBytes, err := openssl.GenerateKeyECDH(name) + if err != nil { + t.Fatal(err) + } + bobKey, _, err := openssl.GenerateKeyECDH(name) + if err != nil { + t.Fatal(err) + } + + alicePubKeyFromPriv, err := aliceKey.PublicKey() + if err != nil { + t.Fatal(err) + } + alicePubBytes := alicePubKeyFromPriv.Bytes() + want := len(alicPrivBytes) + var got int + if tt == "X25519" { + got = len(alicePubBytes) + } else { + got = (len(alicePubBytes) - 1) / 2 // subtract encoding prefix and divide by the number of components + } + if want != got { + t.Fatalf("public key size mismatch: want: %v, got: %v", want, got) + } + alicePubKey, err := openssl.NewPublicKeyECDH(name, alicePubBytes) + if err != nil { + t.Error(err) + } + + bobPubKeyFromPriv, err := bobKey.PublicKey() + if err != nil { + t.Error(err) + } + _, err = openssl.NewPublicKeyECDH(name, bobPubKeyFromPriv.Bytes()) + if err != nil { + t.Error(err) + } + + bobSecret, err := openssl.ECDH(bobKey, alicePubKey) + if err != nil { + t.Fatal(err) + } + aliceSecret, err := openssl.ECDH(aliceKey, bobPubKeyFromPriv) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(bobSecret, aliceSecret) { + t.Error("two ECDH computations came out different") + } + }) + } +} + +// The following vectors have been copied from +// https://github.com/golang/go/blob/bb0d8297d76cb578baad8fa1485565d9acf44cc5/src/crypto/ecdh/ecdh_test.go. + +var ecdhvectors = []struct { + Name string + PrivateKey, PublicKey string + PeerPublicKey string + SharedSecret string +}{ + // NIST vectors from CAVS 14.1, ECC CDH Primitive (SP800-56A). + { + Name: "P-256", + PrivateKey: "7d7dc5f71eb29ddaf80d6214632eeae03d9058af1fb6d22ed80badb62bc1a534", + PublicKey: "04ead218590119e8876b29146ff89ca61770c4edbbf97d38ce385ed281d8a6b230" + + "28af61281fd35e2fa7002523acc85a429cb06ee6648325389f59edfce1405141", + PeerPublicKey: "04700c48f77f56584c5cc632ca65640db91b6bacce3a4df6b42ce7cc838833d287" + + "db71e509e3fd9b060ddb20ba5c51dcc5948d46fbf640dfe0441782cab85fa4ac", + SharedSecret: "46fc62106420ff012e54a434fbdd2d25ccc5852060561e68040dd7778997bd7b", + }, + { + Name: "P-384", + PrivateKey: "3cc3122a68f0d95027ad38c067916ba0eb8c38894d22e1b15618b6818a661774ad463b205da88cf699ab4d43c9cf98a1", + PublicKey: "049803807f2f6d2fd966cdd0290bd410c0190352fbec7ff6247de1302df86f25d34fe4a97bef60cff548355c015dbb3e5f" + + "ba26ca69ec2f5b5d9dad20cc9da711383a9dbe34ea3fa5a2af75b46502629ad54dd8b7d73a8abb06a3a3be47d650cc99", + PeerPublicKey: "04a7c76b970c3b5fe8b05d2838ae04ab47697b9eaf52e764592efda27fe7513272734466b400091adbf2d68c58e0c50066" + + "ac68f19f2e1cb879aed43a9969b91a0839c4c38a49749b661efedf243451915ed0905a32b060992b468c64766fc8437a", + SharedSecret: "5f9d29dc5e31a163060356213669c8ce132e22f57c9a04f40ba7fcead493b457e5621e766c40a2e3d4d6a04b25e533f1", + }, + // For some reason all field elements in the test vector (both scalars and + // base field elements), but not the shared secret output, have two extra + // leading zero bytes (which in big-endian are irrelevant). Removed here. + { + Name: "P-521", + PrivateKey: "017eecc07ab4b329068fba65e56a1f8890aa935e57134ae0ffcce802735151f4eac6564f6ee9974c5e6887a1fefee5743ae2241bfeb95d5ce31ddcb6f9edb4d6fc47", + PublicKey: "0400602f9d0cf9e526b29e22381c203c48a886c2b0673033366314f1ffbcba240ba42f4ef38a76174635f91e6b4ed34275eb01c8467d05ca80315bf1a7bbd945f550a5" + + "01b7c85f26f5d4b2d7355cf6b02117659943762b6d1db5ab4f1dbc44ce7b2946eb6c7de342962893fd387d1b73d7a8672d1f236961170b7eb3579953ee5cdc88cd2d", + PeerPublicKey: "0400685a48e86c79f0f0875f7bc18d25eb5fc8c0b07e5da4f4370f3a9490340854334b1e1b87fa395464c60626124a4e70d0f785601d37c09870ebf176666877a2046d" + + "01ba52c56fc8776d9e8f5db4f0cc27636d0b741bbe05400697942e80b739884a83bde99e0f6716939e632bc8986fa18dccd443a348b6c3e522497955a4f3c302f676", + SharedSecret: "005fc70477c3e63bc3954bd0df3ea0d1f41ee21746ed95fc5e1fdf90930d5e136672d72cc770742d1711c3c3a4c334a0ad9759436a4d3c5bf6e74b9578fac148c831", + }, +} + +func TestVectors(t *testing.T) { + for _, tt := range ecdhvectors { + t.Run(tt.Name, func(t *testing.T) { + key, err := openssl.NewPrivateKeyECDH(tt.Name, hexDecode(t, tt.PrivateKey)) + if err != nil { + t.Fatal(err) + } + pub, err := key.PublicKey() + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(pub.Bytes(), hexDecode(t, tt.PublicKey)) { + t.Error("public key derived from the private key does not match") + } + peer, err := openssl.NewPublicKeyECDH(tt.Name, hexDecode(t, tt.PeerPublicKey)) + if err != nil { + t.Fatal(err) + } + secret, err := openssl.ECDH(key, peer) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(secret, hexDecode(t, tt.SharedSecret)) { + t.Error("shared secret does not match") + } + }) + } +} + +func hexDecode(t *testing.T, s string) []byte { + b, err := hex.DecodeString(s) + if err != nil { + t.Fatal("invalid hex string:", s) + } + return b +} diff --git a/openssl/ecdsa.go b/openssl/ecdsa.go index e5a3e03..a9edf8d 100644 --- a/openssl/ecdsa.go +++ b/openssl/ecdsa.go @@ -102,7 +102,7 @@ func newECKey(curve string, X, Y, D BigInt) (pkey C.GO_EVP_PKEY_PTR, err error) bx = bigToBN(X) by = bigToBN(Y) if bx == nil || by == nil { - return nil, newOpenSSLError("BN_bin2bn failed") + return nil, newOpenSSLError("BN_lebin2bn failed") } if key = C.go_openssl_EC_KEY_new_by_curve_name(nid); key == nil { return nil, newOpenSSLError("EC_KEY_new_by_curve_name failed") @@ -113,7 +113,7 @@ func newECKey(curve string, X, Y, D BigInt) (pkey C.GO_EVP_PKEY_PTR, err error) if D != nil { bd := bigToBN(D) if bd == nil { - return nil, newOpenSSLError("BN_bin2bn failed") + return nil, newOpenSSLError("BN_lebin2bn failed") } defer C.go_openssl_BN_free(bd) if C.go_openssl_EC_KEY_set_private_key(key, bd) != 1 { diff --git a/openssl/openssl.go b/openssl/openssl.go index 95f3e88..e3e13d7 100644 --- a/openssl/openssl.go +++ b/openssl/openssl.go @@ -251,6 +251,13 @@ func wbase(b BigInt) *C.uchar { return (*C.uchar)(unsafe.Pointer(&b[0])) } +func bytesToBN(x []byte) C.GO_BIGNUM_PTR { + if len(x) == 0 { + return nil + } + return C.go_openssl_BN_bin2bn(base(x), C.int(len(x)), nil) +} + func bigToBN(x BigInt) C.GO_BIGNUM_PTR { if len(x) == 0 { return nil diff --git a/openssl/openssl_funcs.h b/openssl/openssl_funcs.h index d6fc72a..46f592c 100644 --- a/openssl/openssl_funcs.h +++ b/openssl/openssl_funcs.h @@ -38,9 +38,13 @@ enum { // #include enum { - GO_EVP_PKEY_CTRL_EC_PARAMGEN_CURVE_NID = 0x1001 + GO_EVP_PKEY_CTRL_EC_PARAMGEN_CURVE_NID = 0x1001, }; +typedef enum { + GO_POINT_CONVERSION_UNCOMPRESSED = 4, +} point_conversion_form_t; + // #include enum { GO_NID_X9_62_prime256v1 = 415, @@ -204,16 +208,21 @@ DEFINEFUNC(void, BN_clear_free, (GO_BIGNUM_PTR arg0), (arg0)) \ DEFINEFUNC(int, BN_num_bits, (const GO_BIGNUM_PTR arg0), (arg0)) \ DEFINEFUNC(GO_BIGNUM_PTR, BN_bin2bn, (const unsigned char *arg0, int arg1, GO_BIGNUM_PTR arg2), (arg0, arg1, arg2)) \ DEFINEFUNC(int, BN_bn2bin, (const GO_BIGNUM_PTR arg0, unsigned char *arg1), (arg0, arg1)) \ -/* bn_lebin2bn and bn_bn2lebinpad are not exported in any OpenSSL 1.0.2, but they exist. */ \ +/* bn_lebin2bn, bn_bn2lebinpad and BN_bn2binpad are not exported in any OpenSSL 1.0.2, but they exist. */ \ /*check:from=1.1.0*/ DEFINEFUNC_RENAMED_1_1(GO_BIGNUM_PTR, BN_lebin2bn, bn_lebin2bn, (const unsigned char *s, int len, GO_BIGNUM_PTR ret), (s, len, ret)) \ /*check:from=1.1.0*/ DEFINEFUNC_RENAMED_1_1(int, BN_bn2lebinpad, bn_bn2lebinpad, (const GO_BIGNUM_PTR a, unsigned char *to, int tolen), (a, to, tolen)) \ +/*check:from=1.1.0*/ DEFINEFUNC_RENAMED_1_1(int, BN_bn2binpad, bn_bn2binpad, (const GO_BIGNUM_PTR a, unsigned char *to, int tolen), (a, to, tolen)) \ DEFINEFUNC(void, EC_GROUP_free, (GO_EC_GROUP_PTR arg0), (arg0)) \ DEFINEFUNC(GO_EC_POINT_PTR, EC_POINT_new, (const GO_EC_GROUP_PTR arg0), (arg0)) \ DEFINEFUNC(void, EC_POINT_free, (GO_EC_POINT_PTR arg0), (arg0)) \ DEFINEFUNC(int, EC_POINT_get_affine_coordinates_GFp, (const GO_EC_GROUP_PTR arg0, const GO_EC_POINT_PTR arg1, GO_BIGNUM_PTR arg2, GO_BIGNUM_PTR arg3, GO_BN_CTX_PTR arg4), (arg0, arg1, arg2, arg3, arg4)) \ +DEFINEFUNC(size_t, EC_POINT_point2oct, (const GO_EC_GROUP_PTR group, const GO_EC_POINT_PTR p, point_conversion_form_t form, unsigned char *buf, size_t len, GO_BN_CTX_PTR ctx), (group, p, form, buf, len, ctx)) \ +DEFINEFUNC(int, EC_POINT_oct2point, (const GO_EC_GROUP_PTR group, GO_EC_POINT_PTR p, const unsigned char *buf, size_t len, GO_BN_CTX_PTR ctx), (group, p, buf, len, ctx)) \ +DEFINEFUNC(int, EC_POINT_mul, (const GO_EC_GROUP_PTR group, GO_EC_POINT_PTR r, const GO_BIGNUM_PTR n, const GO_EC_POINT_PTR q, const GO_BIGNUM_PTR m, GO_BN_CTX_PTR ctx), (group, r, n, q, m, ctx)) \ DEFINEFUNC(GO_EC_KEY_PTR, EC_KEY_new_by_curve_name, (int arg0), (arg0)) \ DEFINEFUNC(int, EC_KEY_set_public_key_affine_coordinates, (GO_EC_KEY_PTR key, GO_BIGNUM_PTR x, GO_BIGNUM_PTR y), (key, x, y)) \ DEFINEFUNC(void, EC_KEY_free, (GO_EC_KEY_PTR arg0), (arg0)) \ +DEFINEFUNC(GO_EC_GROUP_PTR, EC_GROUP_new_by_curve_name, (int nid), (nid)) \ DEFINEFUNC(const GO_EC_GROUP_PTR, EC_KEY_get0_group, (const GO_EC_KEY_PTR arg0), (arg0)) \ DEFINEFUNC(int, EC_KEY_set_private_key, (GO_EC_KEY_PTR arg0, const GO_BIGNUM_PTR arg1), (arg0, arg1)) \ DEFINEFUNC(const GO_BIGNUM_PTR, EC_KEY_get0_private_key, (const GO_EC_KEY_PTR arg0), (arg0)) \