diff --git a/openssl/ecdh.go b/openssl/ecdh.go index a13c411..84ae334 100644 --- a/openssl/ecdh.go +++ b/openssl/ecdh.go @@ -14,50 +14,65 @@ import ( ) type PublicKeyECDH struct { - curve string - key C.GO_EC_POINT_PTR + _pkey C.GO_EVP_PKEY_PTR bytes []byte } func (k *PublicKeyECDH) finalize() { - C.go_openssl_EC_POINT_free(k.key) + C.go_openssl_EVP_PKEY_free(k._pkey) } type PrivateKeyECDH struct { - curve string - key C.GO_EC_KEY_PTR + _pkey C.GO_EVP_PKEY_PTR } func (k *PrivateKeyECDH) finalize() { - C.go_openssl_EC_KEY_free(k.key) + C.go_openssl_EVP_PKEY_free(k._pkey) } func NewPublicKeyECDH(curve string, bytes []byte) (*PublicKeyECDH, error) { if len(bytes) < 1 { return nil, errors.New("NewPublicKeyECDH: missing key") } - nid, err := curveNID(curve) if err != nil { return nil, err } - - group := C.go_openssl_EC_GROUP_new_by_curve_name(nid) - if group == nil { - return nil, newOpenSSLError("EC_GROUP_new_by_curve_name") - } - defer C.go_openssl_EC_GROUP_free(group) - key := C.go_openssl_EC_POINT_new(group) + key := C.go_openssl_EC_KEY_new_by_curve_name(nid) if key == nil { - return nil, newOpenSSLError("EC_POINT_new") + return nil, newOpenSSLError("EC_KEY_new_by_curve_name") } - ok := C.go_openssl_EC_POINT_oct2point(group, key, base(bytes), C.size_t(len(bytes)), nil) != 0 - if !ok { - C.go_openssl_EC_POINT_free(key) - return nil, errors.New("point not on curve") + var k *PublicKeyECDH + defer func() { + if k == nil { + C.go_openssl_EC_KEY_free(key) + } + }() + if vMajor == 1 && vMinor == 0 { + // EC_KEY_oct2key does not exist on OpenSSL 1.0.2, + // we have to simulate it. + group := C.go_openssl_EC_KEY_get0_group(key) + pt := C.go_openssl_EC_POINT_new(group) + if pt == nil { + return nil, newOpenSSLError("EC_POINT_new") + } + defer C.go_openssl_EC_POINT_free(pt) + if C.go_openssl_EC_POINT_oct2point(group, pt, base(bytes), C.size_t(len(bytes)), nil) != 1 { + return nil, errors.New("point not on curve") + } + if C.go_openssl_EC_KEY_set_public_key(key, pt) != 1 { + return nil, newOpenSSLError("EC_KEY_set_public_key") + } + } else { + if C.go_openssl_EC_KEY_oct2key(key, base(bytes), C.size_t(len(bytes)), nil) != 1 { + return nil, newOpenSSLError("EC_KEY_oct2key") + } + } + pkey, err := newEVPPKEY(key) + if err != nil { + return nil, err } - - k := &PublicKeyECDH{curve, key, append([]byte(nil), bytes...)} + k = &PublicKeyECDH{pkey, append([]byte(nil), bytes...)} runtime.SetFinalizer(k, (*PublicKeyECDH).finalize) return k, nil } @@ -69,144 +84,128 @@ func NewPrivateKeyECDH(curve string, bytes []byte) (*PrivateKeyECDH, error) { if err != nil { return nil, err } + b := bytesToBN(bytes) + if b == nil { + return nil, newOpenSSLError("BN_bin2bn") + } + defer C.go_openssl_BN_free(b) key := C.go_openssl_EC_KEY_new_by_curve_name(nid) if key == nil { return nil, newOpenSSLError("EC_KEY_new_by_curve_name") } - b := bytesToBN(bytes) - ok := b != nil && C.go_openssl_EC_KEY_set_private_key(key, b) != 0 - if b != nil { - C.go_openssl_BN_free(b) - } - if !ok { - C.go_openssl_EC_KEY_free(key) + var pkey C.GO_EVP_PKEY_PTR + defer func() { + if pkey == nil { + C.go_openssl_EC_KEY_free(key) + } + }() + if C.go_openssl_EC_KEY_set_private_key(key, b) != 1 { return nil, newOpenSSLError("EC_KEY_set_private_key") } - k := &PrivateKeyECDH{curve, key} + pkey, err = newEVPPKEY(key) + if err != nil { + return nil, err + } + k := &PrivateKeyECDH{pkey} runtime.SetFinalizer(k, (*PrivateKeyECDH).finalize) return k, nil } func (k *PrivateKeyECDH) PublicKey() (*PublicKeyECDH, error) { defer runtime.KeepAlive(k) - - group := C.go_openssl_EC_KEY_get0_group(k.key) + key := C.go_openssl_EVP_PKEY_get1_EC_KEY(k._pkey) + if key == nil { + return nil, newOpenSSLError("EVP_PKEY_get1_EC_KEY") + } + defer C.go_openssl_EC_KEY_free(key) + group := C.go_openssl_EC_KEY_get0_group(key) if group == nil { return nil, newOpenSSLError("EC_KEY_get0_group") } - kbig := C.go_openssl_EC_KEY_get0_private_key(k.key) - if kbig == nil { - return nil, newOpenSSLError("EC_KEY_get0_private_key") - } - pt := C.go_openssl_EC_POINT_new(group) + pt := C.go_openssl_EC_KEY_get0_public_key(key) if pt == nil { - return nil, newOpenSSLError("EC_POINT_new") - } - if C.go_openssl_EC_POINT_mul(group, pt, kbig, nil, nil, nil) == 0 { - C.go_openssl_EC_POINT_free(pt) - return nil, newOpenSSLError("EC_POINT_mul") + // The public key will be nil if k has been generated using + // NewPrivateKeyECDH instead of GenerateKeyECDH. + // + // OpenSSL does not expose any method to generate the public + // key from the private key [1], so we have to calculate it here + // https://github.com/openssl/openssl/issues/18437#issuecomment-1144717206 + pt = C.go_openssl_EC_POINT_new(group) + if pt == nil { + return nil, newOpenSSLError("EC_POINT_new") + } + defer C.go_openssl_EC_POINT_free(pt) + kbig := C.go_openssl_EC_KEY_get0_private_key(key) + if C.go_openssl_EC_POINT_mul(group, pt, kbig, nil, nil, nil) == 0 { + return nil, newOpenSSLError("EC_POINT_mul") + } + } + n := C.go_openssl_EC_POINT_point2oct(group, pt, C.GO_POINT_CONVERSION_UNCOMPRESSED, nil, 0, nil) + if n == 0 { + return nil, newOpenSSLError("EC_POINT_point2oct") } - bytes, err := pointBytesECDH(k.curve, group, pt) - if err != nil { - C.go_openssl_EC_POINT_free(pt) - return nil, err + bytes := make([]byte, n) + n = C.go_openssl_EC_POINT_point2oct(group, pt, C.GO_POINT_CONVERSION_UNCOMPRESSED, base(bytes), C.size_t(len(bytes)), nil) + if int(n) != len(bytes) { + return nil, newOpenSSLError("EC_POINT_point2oct") } - pub := &PublicKeyECDH{k.curve, pt, bytes} + pub := &PublicKeyECDH{k._pkey, bytes} // Note: Same as in NewPublicKeyECDH regarding finalizer and KeepAlive. runtime.SetFinalizer(pub, (*PublicKeyECDH).finalize) return pub, nil } -func pointBytesECDH(curve string, group C.GO_EC_GROUP_PTR, pt C.GO_EC_POINT_PTR) ([]byte, error) { - out := make([]byte, 1+2*curveSize(curve)) - n := C.go_openssl_EC_POINT_point2oct(group, pt, C.GO_POINT_CONVERSION_UNCOMPRESSED, base(out), C.size_t(len(out)), nil) - if int(n) != len(out) { - return nil, newOpenSSLError("EC_POINT_point2oct") - } - return out, nil -} - func ECDH(priv *PrivateKeyECDH, pub *PublicKeyECDH) ([]byte, error) { - group := C.go_openssl_EC_KEY_get0_group(priv.key) - if group == nil { - return nil, newOpenSSLError("EC_KEY_get0_group") + defer runtime.KeepAlive(priv) + defer runtime.KeepAlive(pub) + ctx := C.go_openssl_EVP_PKEY_CTX_new(priv._pkey, nil) + if ctx == nil { + return nil, newOpenSSLError("EVP_PKEY_CTX_new") } - privBig := C.go_openssl_EC_KEY_get0_private_key(priv.key) - if privBig == nil { - return nil, newOpenSSLError("EC_KEY_get0_private_key") + defer C.go_openssl_EVP_PKEY_CTX_free(ctx) + if C.go_openssl_EVP_PKEY_derive_init(ctx) != 1 { + return nil, newOpenSSLError("EVP_PKEY_derive_init") } - pt := C.go_openssl_EC_POINT_new(group) - if pt == nil { - return nil, newOpenSSLError("EC_POINT_new") - } - defer C.go_openssl_EC_POINT_free(pt) - if C.go_openssl_EC_POINT_mul(group, pt, nil, pub.key, privBig, nil) == 0 { - return nil, newOpenSSLError("EC_POINT_mul") - } - out, err := xCoordBytesECDH(priv.curve, group, pt) - if err != nil { - return nil, err + if C.go_openssl_EVP_PKEY_derive_set_peer(ctx, pub._pkey) != 1 { + return nil, newOpenSSLError("EVP_PKEY_derive_set_peer") } - return out, nil -} - -func xCoordBytesECDH(curve string, group C.GO_EC_GROUP_PTR, pt C.GO_EC_POINT_PTR) ([]byte, error) { - big := C.go_openssl_BN_new() - defer C.go_openssl_BN_free(big) - if C.go_openssl_EC_POINT_get_affine_coordinates_GFp(group, pt, big, nil, nil) == 0 { - return nil, newOpenSSLError("EC_POINT_get_affine_coordinates_GFp") + var outLen C.size_t + if C.go_openssl_EVP_PKEY_derive(ctx, nil, &outLen) != 1 { + return nil, newOpenSSLError("EVP_PKEY_derive_init") } - return bigBytesECDH(curve, big) -} - -func bigBytesECDH(curve string, big C.GO_BIGNUM_PTR) ([]byte, error) { - out := make([]byte, curveSize(curve)) - if C.go_openssl_BN_bn2binpad(big, base(out), C.int(len(out))) == 0 { - return nil, newOpenSSLError("BN_bn2binpad") + out := make([]byte, outLen) + if C.go_openssl_EVP_PKEY_derive(ctx, base(out), &outLen) != 1 { + return nil, newOpenSSLError("EVP_PKEY_derive_init") } return out, nil } -func curveSize(curve string) int { - switch curve { - default: - panic("openssl: unknown curve " + curve) - case "P-256": - return 256 / 8 - case "P-384": - return 384 / 8 - case "P-521": - return (521 + 7) / 8 - } -} - func GenerateKeyECDH(curve string) (*PrivateKeyECDH, []byte, error) { pkey, err := generateEVPPKey(C.GO_EVP_PKEY_EC, 0, curve) if err != nil { return nil, nil, err } - defer C.go_openssl_EVP_PKEY_free(pkey) + var k *PrivateKeyECDH + defer func() { + if k == nil { + C.go_openssl_EVP_PKEY_free(pkey) + } + }() key := C.go_openssl_EVP_PKEY_get1_EC_KEY(pkey) if key == nil { return nil, nil, newOpenSSLError("EVP_PKEY_get1_EC_KEY") } - group := C.go_openssl_EC_KEY_get0_group(key) - if group == nil { - C.go_openssl_EC_KEY_free(key) - return nil, nil, newOpenSSLError("EC_KEY_get0_group") - } + defer C.go_openssl_EC_KEY_free(key) b := C.go_openssl_EC_KEY_get0_private_key(key) if b == nil { - C.go_openssl_EC_KEY_free(key) return nil, nil, newOpenSSLError("EC_KEY_get0_private_key") } - bytes, err := bigBytesECDH(curve, b) - if err != nil { - C.go_openssl_EC_KEY_free(key) - return nil, nil, err + bits := C.go_openssl_EVP_PKEY_get_bits(pkey) + out := make([]byte, (bits+7)/8) + if C.go_openssl_BN_bn2binpad(b, base(out), C.int(len(out))) == 0 { + return nil, nil, newOpenSSLError("BN_bn2binpad") } - - k := &PrivateKeyECDH{curve, key} + k = &PrivateKeyECDH{pkey} runtime.SetFinalizer(k, (*PrivateKeyECDH).finalize) - return k, bytes, nil + return k, out, nil } diff --git a/openssl/ecdsa.go b/openssl/ecdsa.go index a9edf8d..de4aa0e 100644 --- a/openssl/ecdsa.go +++ b/openssl/ecdsa.go @@ -11,7 +11,6 @@ import "C" import ( "errors" "runtime" - "unsafe" ) type PrivateKeyECDSA struct { @@ -73,13 +72,12 @@ func NewPublicKeyECDSA(curve string, X, Y BigInt) (*PublicKeyECDSA, error) { return k, nil } -func newECKey(curve string, X, Y, D BigInt) (pkey C.GO_EVP_PKEY_PTR, err error) { - var nid C.int - if nid, err = curveNID(curve); err != nil { +func newECKey(curve string, X, Y, D BigInt) (C.GO_EVP_PKEY_PTR, error) { + nid, err := curveNID(curve) + if err != nil { return nil, err } - var bx, by C.GO_BIGNUM_PTR - var key C.GO_EC_KEY_PTR + var bx, by, bd C.GO_BIGNUM_PTR defer func() { if bx != nil { C.go_openssl_BN_free(bx) @@ -87,44 +85,35 @@ func newECKey(curve string, X, Y, D BigInt) (pkey C.GO_EVP_PKEY_PTR, err error) if by != nil { C.go_openssl_BN_free(by) } - if err != nil { - if key != nil { - C.go_openssl_EC_KEY_free(key) - } - if pkey != nil { - C.go_openssl_EVP_PKEY_free(pkey) - // pkey is a named return, so in case of error - // it have to be cleared before returing. - pkey = nil - } + if bd != nil { + C.go_openssl_BN_free(bd) } }() bx = bigToBN(X) by = bigToBN(Y) - if bx == nil || by == nil { + bd = bigToBN(D) + if bx == nil || by == nil || (D != nil && bd == nil) { return nil, newOpenSSLError("BN_lebin2bn failed") } - if key = C.go_openssl_EC_KEY_new_by_curve_name(nid); key == nil { + key := C.go_openssl_EC_KEY_new_by_curve_name(nid) + if key == nil { return nil, newOpenSSLError("EC_KEY_new_by_curve_name failed") } + var pkey C.GO_EVP_PKEY_PTR + defer func() { + if pkey == nil { + defer C.go_openssl_EC_KEY_free(key) + } + }() if C.go_openssl_EC_KEY_set_public_key_affine_coordinates(key, bx, by) != 1 { return nil, newOpenSSLError("EC_KEY_set_public_key_affine_coordinates failed") } - if D != nil { - bd := bigToBN(D) - if bd == nil { - return nil, newOpenSSLError("BN_lebin2bn failed") - } - defer C.go_openssl_BN_free(bd) - if C.go_openssl_EC_KEY_set_private_key(key, bd) != 1 { - return nil, newOpenSSLError("EC_KEY_set_private_key failed") - } - } - if pkey = C.go_openssl_EVP_PKEY_new(); pkey == nil { - return nil, newOpenSSLError("EVP_PKEY_new failed") + if D != nil && C.go_openssl_EC_KEY_set_private_key(key, bd) != 1 { + return nil, newOpenSSLError("EC_KEY_set_private_key failed") } - if C.go_openssl_EVP_PKEY_assign(pkey, C.GO_EVP_PKEY_EC, (unsafe.Pointer)(key)) != 1 { - return nil, newOpenSSLError("EVP_PKEY_assign failed") + pkey, err = newEVPPKEY(key) + if err != nil { + return nil, err } return pkey, nil } diff --git a/openssl/evpkey.go b/openssl/evpkey.go index 97e39c0..2965d01 100644 --- a/openssl/evpkey.go +++ b/openssl/evpkey.go @@ -311,3 +311,15 @@ func evpVerify(withKey withKeyFunc, padding C.int, saltLen C.int, h crypto.Hash, } return verifyEVP(withKey, padding, nil, nil, saltLen, h, verifyInit, verify, sig, hashed) } + +func newEVPPKEY(key C.GO_EC_KEY_PTR) (C.GO_EVP_PKEY_PTR, error) { + pkey := C.go_openssl_EVP_PKEY_new() + if pkey == nil { + return nil, newOpenSSLError("EVP_PKEY_new failed") + } + if C.go_openssl_EVP_PKEY_assign(pkey, C.GO_EVP_PKEY_EC, (unsafe.Pointer)(key)) != 1 { + C.go_openssl_EVP_PKEY_free(pkey) + return nil, newOpenSSLError("EVP_PKEY_assign failed") + } + return pkey, nil +} diff --git a/openssl/openssl_funcs.h b/openssl/openssl_funcs.h index 46f592c..fe72f90 100644 --- a/openssl/openssl_funcs.h +++ b/openssl/openssl_funcs.h @@ -217,14 +217,15 @@ DEFINEFUNC(GO_EC_POINT_PTR, EC_POINT_new, (const GO_EC_GROUP_PTR arg0), (arg0)) DEFINEFUNC(void, EC_POINT_free, (GO_EC_POINT_PTR arg0), (arg0)) \ DEFINEFUNC(int, EC_POINT_get_affine_coordinates_GFp, (const GO_EC_GROUP_PTR arg0, const GO_EC_POINT_PTR arg1, GO_BIGNUM_PTR arg2, GO_BIGNUM_PTR arg3, GO_BN_CTX_PTR arg4), (arg0, arg1, arg2, arg3, arg4)) \ DEFINEFUNC(size_t, EC_POINT_point2oct, (const GO_EC_GROUP_PTR group, const GO_EC_POINT_PTR p, point_conversion_form_t form, unsigned char *buf, size_t len, GO_BN_CTX_PTR ctx), (group, p, form, buf, len, ctx)) \ -DEFINEFUNC(int, EC_POINT_oct2point, (const GO_EC_GROUP_PTR group, GO_EC_POINT_PTR p, const unsigned char *buf, size_t len, GO_BN_CTX_PTR ctx), (group, p, buf, len, ctx)) \ +DEFINEFUNC_LEGACY_1_0(int, EC_POINT_oct2point, (const GO_EC_GROUP_PTR group, GO_EC_POINT_PTR p, const unsigned char *buf, size_t len, GO_BN_CTX_PTR ctx), (group, p, buf, len, ctx)) \ DEFINEFUNC(int, EC_POINT_mul, (const GO_EC_GROUP_PTR group, GO_EC_POINT_PTR r, const GO_BIGNUM_PTR n, const GO_EC_POINT_PTR q, const GO_BIGNUM_PTR m, GO_BN_CTX_PTR ctx), (group, r, n, q, m, ctx)) \ DEFINEFUNC(GO_EC_KEY_PTR, EC_KEY_new_by_curve_name, (int arg0), (arg0)) \ DEFINEFUNC(int, EC_KEY_set_public_key_affine_coordinates, (GO_EC_KEY_PTR key, GO_BIGNUM_PTR x, GO_BIGNUM_PTR y), (key, x, y)) \ +DEFINEFUNC_LEGACY_1_0(int, EC_KEY_set_public_key, (GO_EC_KEY_PTR key, const GO_EC_POINT_PTR pub), (key, pub)) \ DEFINEFUNC(void, EC_KEY_free, (GO_EC_KEY_PTR arg0), (arg0)) \ -DEFINEFUNC(GO_EC_GROUP_PTR, EC_GROUP_new_by_curve_name, (int nid), (nid)) \ DEFINEFUNC(const GO_EC_GROUP_PTR, EC_KEY_get0_group, (const GO_EC_KEY_PTR arg0), (arg0)) \ DEFINEFUNC(int, EC_KEY_set_private_key, (GO_EC_KEY_PTR arg0, const GO_BIGNUM_PTR arg1), (arg0, arg1)) \ +DEFINEFUNC_1_1(int, EC_KEY_oct2key, (GO_EC_KEY_PTR eckey, const unsigned char *buf, size_t len, GO_BN_CTX_PTR ctx), (eckey, buf, len, ctx)) \ DEFINEFUNC(const GO_BIGNUM_PTR, EC_KEY_get0_private_key, (const GO_EC_KEY_PTR arg0), (arg0)) \ DEFINEFUNC(const GO_EC_POINT_PTR, EC_KEY_get0_public_key, (const GO_EC_KEY_PTR arg0), (arg0)) \ DEFINEFUNC(GO_RSA_PTR, RSA_new, (void), ()) \ @@ -255,9 +256,10 @@ DEFINEFUNC(const GO_EVP_CIPHER_PTR, EVP_aes_256_gcm, (void), ()) \ DEFINEFUNC(void, EVP_CIPHER_CTX_free, (GO_EVP_CIPHER_CTX_PTR arg0), (arg0)) \ DEFINEFUNC(int, EVP_CIPHER_CTX_ctrl, (GO_EVP_CIPHER_CTX_PTR ctx, int type, int arg, void *ptr), (ctx, type, arg, ptr)) \ DEFINEFUNC(GO_EVP_PKEY_PTR, EVP_PKEY_new, (void), ()) \ -/* EVP_PKEY_size pkey parameter is const since OpenSSL 1.1.1. */ \ +/* EVP_PKEY_size and EVP_PKEY_get_bits pkey parameter is const since OpenSSL 1.1.1. */ \ /* Exclude it from headercheck tool when using previous OpenSSL versions. */ \ /*check:from=1.1.1*/ DEFINEFUNC_RENAMED_3_0(int, EVP_PKEY_get_size, EVP_PKEY_size, (const GO_EVP_PKEY_PTR pkey), (pkey)) \ +/*check:from=1.1.1*/ DEFINEFUNC_RENAMED_3_0(int, EVP_PKEY_get_bits, EVP_PKEY_bits, (const GO_EVP_PKEY_PTR pkey), (pkey)) \ DEFINEFUNC(void, EVP_PKEY_free, (GO_EVP_PKEY_PTR arg0), (arg0)) \ DEFINEFUNC(GO_EC_KEY_PTR, EVP_PKEY_get1_EC_KEY, (GO_EVP_PKEY_PTR pkey), (pkey)) \ DEFINEFUNC(GO_RSA_PTR, EVP_PKEY_get1_RSA, (GO_EVP_PKEY_PTR pkey), (pkey)) \ @@ -276,6 +278,9 @@ DEFINEFUNC(int, EVP_PKEY_encrypt_init, (GO_EVP_PKEY_CTX_PTR arg0), (arg0)) \ DEFINEFUNC(int, EVP_PKEY_sign_init, (GO_EVP_PKEY_CTX_PTR arg0), (arg0)) \ DEFINEFUNC(int, EVP_PKEY_verify_init, (GO_EVP_PKEY_CTX_PTR arg0), (arg0)) \ DEFINEFUNC(int, EVP_PKEY_sign, (GO_EVP_PKEY_CTX_PTR arg0, unsigned char *arg1, size_t *arg2, const unsigned char *arg3, size_t arg4), (arg0, arg1, arg2, arg3, arg4)) \ +DEFINEFUNC(int, EVP_PKEY_derive_init, (GO_EVP_PKEY_CTX_PTR ctx), (ctx)) \ +DEFINEFUNC(int, EVP_PKEY_derive_set_peer, (GO_EVP_PKEY_CTX_PTR ctx, GO_EVP_PKEY_PTR peer), (ctx, peer)) \ +DEFINEFUNC(int, EVP_PKEY_derive, (GO_EVP_PKEY_CTX_PTR ctx, unsigned char *key, size_t *keylen), (ctx, key, keylen)) \ DEFINEFUNC_3_0(GO_EVP_MAC_PTR, EVP_MAC_fetch, (GO_OSSL_LIB_CTX_PTR ctx, const char *algorithm, const char *properties), (ctx, algorithm, properties)) \ DEFINEFUNC_3_0(GO_EVP_MAC_CTX_PTR, EVP_MAC_CTX_new, (GO_EVP_MAC_PTR arg0), (arg0)) \ DEFINEFUNC_3_0(void, EVP_MAC_CTX_free, (GO_EVP_MAC_CTX_PTR arg0), (arg0)) \