Skip to content

Commit

Permalink
Merge pull request #44 from microsoft/ecdh-evp
Browse files Browse the repository at this point in the history
Implement ECDH using the EVP interface
  • Loading branch information
qmuntal authored Jan 18, 2023
2 parents 4033c5d + 2edda65 commit 5eef080
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 149 deletions.
227 changes: 113 additions & 114 deletions openssl/ecdh.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,50 +14,65 @@ import (
)

type PublicKeyECDH struct {
curve string
key C.GO_EC_POINT_PTR
_pkey C.GO_EVP_PKEY_PTR
bytes []byte
}

func (k *PublicKeyECDH) finalize() {
C.go_openssl_EC_POINT_free(k.key)
C.go_openssl_EVP_PKEY_free(k._pkey)
}

type PrivateKeyECDH struct {
curve string
key C.GO_EC_KEY_PTR
_pkey C.GO_EVP_PKEY_PTR
}

func (k *PrivateKeyECDH) finalize() {
C.go_openssl_EC_KEY_free(k.key)
C.go_openssl_EVP_PKEY_free(k._pkey)
}

func NewPublicKeyECDH(curve string, bytes []byte) (*PublicKeyECDH, error) {
if len(bytes) < 1 {
return nil, errors.New("NewPublicKeyECDH: missing key")
}

nid, err := curveNID(curve)
if err != nil {
return nil, err
}

group := C.go_openssl_EC_GROUP_new_by_curve_name(nid)
if group == nil {
return nil, newOpenSSLError("EC_GROUP_new_by_curve_name")
}
defer C.go_openssl_EC_GROUP_free(group)
key := C.go_openssl_EC_POINT_new(group)
key := C.go_openssl_EC_KEY_new_by_curve_name(nid)
if key == nil {
return nil, newOpenSSLError("EC_POINT_new")
return nil, newOpenSSLError("EC_KEY_new_by_curve_name")
}
ok := C.go_openssl_EC_POINT_oct2point(group, key, base(bytes), C.size_t(len(bytes)), nil) != 0
if !ok {
C.go_openssl_EC_POINT_free(key)
return nil, errors.New("point not on curve")
var k *PublicKeyECDH
defer func() {
if k == nil {
C.go_openssl_EC_KEY_free(key)
}
}()
if vMajor == 1 && vMinor == 0 {
// EC_KEY_oct2key does not exist on OpenSSL 1.0.2,
// we have to simulate it.
group := C.go_openssl_EC_KEY_get0_group(key)
pt := C.go_openssl_EC_POINT_new(group)
if pt == nil {
return nil, newOpenSSLError("EC_POINT_new")
}
defer C.go_openssl_EC_POINT_free(pt)
if C.go_openssl_EC_POINT_oct2point(group, pt, base(bytes), C.size_t(len(bytes)), nil) != 1 {
return nil, errors.New("point not on curve")
}
if C.go_openssl_EC_KEY_set_public_key(key, pt) != 1 {
return nil, newOpenSSLError("EC_KEY_set_public_key")
}
} else {
if C.go_openssl_EC_KEY_oct2key(key, base(bytes), C.size_t(len(bytes)), nil) != 1 {
return nil, newOpenSSLError("EC_KEY_oct2key")
}
}
pkey, err := newEVPPKEY(key)
if err != nil {
return nil, err
}

k := &PublicKeyECDH{curve, key, append([]byte(nil), bytes...)}
k = &PublicKeyECDH{pkey, append([]byte(nil), bytes...)}
runtime.SetFinalizer(k, (*PublicKeyECDH).finalize)
return k, nil
}
Expand All @@ -69,144 +84,128 @@ func NewPrivateKeyECDH(curve string, bytes []byte) (*PrivateKeyECDH, error) {
if err != nil {
return nil, err
}
b := bytesToBN(bytes)
if b == nil {
return nil, newOpenSSLError("BN_bin2bn")
}
defer C.go_openssl_BN_free(b)
key := C.go_openssl_EC_KEY_new_by_curve_name(nid)
if key == nil {
return nil, newOpenSSLError("EC_KEY_new_by_curve_name")
}
b := bytesToBN(bytes)
ok := b != nil && C.go_openssl_EC_KEY_set_private_key(key, b) != 0
if b != nil {
C.go_openssl_BN_free(b)
}
if !ok {
C.go_openssl_EC_KEY_free(key)
var pkey C.GO_EVP_PKEY_PTR
defer func() {
if pkey == nil {
C.go_openssl_EC_KEY_free(key)
}
}()
if C.go_openssl_EC_KEY_set_private_key(key, b) != 1 {
return nil, newOpenSSLError("EC_KEY_set_private_key")
}
k := &PrivateKeyECDH{curve, key}
pkey, err = newEVPPKEY(key)
if err != nil {
return nil, err
}
k := &PrivateKeyECDH{pkey}
runtime.SetFinalizer(k, (*PrivateKeyECDH).finalize)
return k, nil
}

func (k *PrivateKeyECDH) PublicKey() (*PublicKeyECDH, error) {
defer runtime.KeepAlive(k)

group := C.go_openssl_EC_KEY_get0_group(k.key)
key := C.go_openssl_EVP_PKEY_get1_EC_KEY(k._pkey)
if key == nil {
return nil, newOpenSSLError("EVP_PKEY_get1_EC_KEY")
}
defer C.go_openssl_EC_KEY_free(key)
group := C.go_openssl_EC_KEY_get0_group(key)
if group == nil {
return nil, newOpenSSLError("EC_KEY_get0_group")
}
kbig := C.go_openssl_EC_KEY_get0_private_key(k.key)
if kbig == nil {
return nil, newOpenSSLError("EC_KEY_get0_private_key")
}
pt := C.go_openssl_EC_POINT_new(group)
pt := C.go_openssl_EC_KEY_get0_public_key(key)
if pt == nil {
return nil, newOpenSSLError("EC_POINT_new")
}
if C.go_openssl_EC_POINT_mul(group, pt, kbig, nil, nil, nil) == 0 {
C.go_openssl_EC_POINT_free(pt)
return nil, newOpenSSLError("EC_POINT_mul")
// The public key will be nil if k has been generated using
// NewPrivateKeyECDH instead of GenerateKeyECDH.
//
// OpenSSL does not expose any method to generate the public
// key from the private key [1], so we have to calculate it here
// https://github.com/openssl/openssl/issues/18437#issuecomment-1144717206
pt = C.go_openssl_EC_POINT_new(group)
if pt == nil {
return nil, newOpenSSLError("EC_POINT_new")
}
defer C.go_openssl_EC_POINT_free(pt)
kbig := C.go_openssl_EC_KEY_get0_private_key(key)
if C.go_openssl_EC_POINT_mul(group, pt, kbig, nil, nil, nil) == 0 {
return nil, newOpenSSLError("EC_POINT_mul")
}
}
n := C.go_openssl_EC_POINT_point2oct(group, pt, C.GO_POINT_CONVERSION_UNCOMPRESSED, nil, 0, nil)
if n == 0 {
return nil, newOpenSSLError("EC_POINT_point2oct")
}
bytes, err := pointBytesECDH(k.curve, group, pt)
if err != nil {
C.go_openssl_EC_POINT_free(pt)
return nil, err
bytes := make([]byte, n)
n = C.go_openssl_EC_POINT_point2oct(group, pt, C.GO_POINT_CONVERSION_UNCOMPRESSED, base(bytes), C.size_t(len(bytes)), nil)
if int(n) != len(bytes) {
return nil, newOpenSSLError("EC_POINT_point2oct")
}
pub := &PublicKeyECDH{k.curve, pt, bytes}
pub := &PublicKeyECDH{k._pkey, bytes}
// Note: Same as in NewPublicKeyECDH regarding finalizer and KeepAlive.
runtime.SetFinalizer(pub, (*PublicKeyECDH).finalize)
return pub, nil
}

func pointBytesECDH(curve string, group C.GO_EC_GROUP_PTR, pt C.GO_EC_POINT_PTR) ([]byte, error) {
out := make([]byte, 1+2*curveSize(curve))
n := C.go_openssl_EC_POINT_point2oct(group, pt, C.GO_POINT_CONVERSION_UNCOMPRESSED, base(out), C.size_t(len(out)), nil)
if int(n) != len(out) {
return nil, newOpenSSLError("EC_POINT_point2oct")
}
return out, nil
}

func ECDH(priv *PrivateKeyECDH, pub *PublicKeyECDH) ([]byte, error) {
group := C.go_openssl_EC_KEY_get0_group(priv.key)
if group == nil {
return nil, newOpenSSLError("EC_KEY_get0_group")
defer runtime.KeepAlive(priv)
defer runtime.KeepAlive(pub)
ctx := C.go_openssl_EVP_PKEY_CTX_new(priv._pkey, nil)
if ctx == nil {
return nil, newOpenSSLError("EVP_PKEY_CTX_new")
}
privBig := C.go_openssl_EC_KEY_get0_private_key(priv.key)
if privBig == nil {
return nil, newOpenSSLError("EC_KEY_get0_private_key")
defer C.go_openssl_EVP_PKEY_CTX_free(ctx)
if C.go_openssl_EVP_PKEY_derive_init(ctx) != 1 {
return nil, newOpenSSLError("EVP_PKEY_derive_init")
}
pt := C.go_openssl_EC_POINT_new(group)
if pt == nil {
return nil, newOpenSSLError("EC_POINT_new")
}
defer C.go_openssl_EC_POINT_free(pt)
if C.go_openssl_EC_POINT_mul(group, pt, nil, pub.key, privBig, nil) == 0 {
return nil, newOpenSSLError("EC_POINT_mul")
}
out, err := xCoordBytesECDH(priv.curve, group, pt)
if err != nil {
return nil, err
if C.go_openssl_EVP_PKEY_derive_set_peer(ctx, pub._pkey) != 1 {
return nil, newOpenSSLError("EVP_PKEY_derive_set_peer")
}
return out, nil
}

func xCoordBytesECDH(curve string, group C.GO_EC_GROUP_PTR, pt C.GO_EC_POINT_PTR) ([]byte, error) {
big := C.go_openssl_BN_new()
defer C.go_openssl_BN_free(big)
if C.go_openssl_EC_POINT_get_affine_coordinates_GFp(group, pt, big, nil, nil) == 0 {
return nil, newOpenSSLError("EC_POINT_get_affine_coordinates_GFp")
var outLen C.size_t
if C.go_openssl_EVP_PKEY_derive(ctx, nil, &outLen) != 1 {
return nil, newOpenSSLError("EVP_PKEY_derive_init")
}
return bigBytesECDH(curve, big)
}

func bigBytesECDH(curve string, big C.GO_BIGNUM_PTR) ([]byte, error) {
out := make([]byte, curveSize(curve))
if C.go_openssl_BN_bn2binpad(big, base(out), C.int(len(out))) == 0 {
return nil, newOpenSSLError("BN_bn2binpad")
out := make([]byte, outLen)
if C.go_openssl_EVP_PKEY_derive(ctx, base(out), &outLen) != 1 {
return nil, newOpenSSLError("EVP_PKEY_derive_init")
}
return out, nil
}

func curveSize(curve string) int {
switch curve {
default:
panic("openssl: unknown curve " + curve)
case "P-256":
return 256 / 8
case "P-384":
return 384 / 8
case "P-521":
return (521 + 7) / 8
}
}

func GenerateKeyECDH(curve string) (*PrivateKeyECDH, []byte, error) {
pkey, err := generateEVPPKey(C.GO_EVP_PKEY_EC, 0, curve)
if err != nil {
return nil, nil, err
}
defer C.go_openssl_EVP_PKEY_free(pkey)
var k *PrivateKeyECDH
defer func() {
if k == nil {
C.go_openssl_EVP_PKEY_free(pkey)
}
}()
key := C.go_openssl_EVP_PKEY_get1_EC_KEY(pkey)
if key == nil {
return nil, nil, newOpenSSLError("EVP_PKEY_get1_EC_KEY")
}
group := C.go_openssl_EC_KEY_get0_group(key)
if group == nil {
C.go_openssl_EC_KEY_free(key)
return nil, nil, newOpenSSLError("EC_KEY_get0_group")
}
defer C.go_openssl_EC_KEY_free(key)
b := C.go_openssl_EC_KEY_get0_private_key(key)
if b == nil {
C.go_openssl_EC_KEY_free(key)
return nil, nil, newOpenSSLError("EC_KEY_get0_private_key")
}
bytes, err := bigBytesECDH(curve, b)
if err != nil {
C.go_openssl_EC_KEY_free(key)
return nil, nil, err
bits := C.go_openssl_EVP_PKEY_get_bits(pkey)
out := make([]byte, (bits+7)/8)
if C.go_openssl_BN_bn2binpad(b, base(out), C.int(len(out))) == 0 {
return nil, nil, newOpenSSLError("BN_bn2binpad")
}

k := &PrivateKeyECDH{curve, key}
k = &PrivateKeyECDH{pkey}
runtime.SetFinalizer(k, (*PrivateKeyECDH).finalize)
return k, bytes, nil
return k, out, nil
}
53 changes: 21 additions & 32 deletions openssl/ecdsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import "C"
import (
"errors"
"runtime"
"unsafe"
)

type PrivateKeyECDSA struct {
Expand Down Expand Up @@ -73,58 +72,48 @@ func NewPublicKeyECDSA(curve string, X, Y BigInt) (*PublicKeyECDSA, error) {
return k, nil
}

func newECKey(curve string, X, Y, D BigInt) (pkey C.GO_EVP_PKEY_PTR, err error) {
var nid C.int
if nid, err = curveNID(curve); err != nil {
func newECKey(curve string, X, Y, D BigInt) (C.GO_EVP_PKEY_PTR, error) {
nid, err := curveNID(curve)
if err != nil {
return nil, err
}
var bx, by C.GO_BIGNUM_PTR
var key C.GO_EC_KEY_PTR
var bx, by, bd C.GO_BIGNUM_PTR
defer func() {
if bx != nil {
C.go_openssl_BN_free(bx)
}
if by != nil {
C.go_openssl_BN_free(by)
}
if err != nil {
if key != nil {
C.go_openssl_EC_KEY_free(key)
}
if pkey != nil {
C.go_openssl_EVP_PKEY_free(pkey)
// pkey is a named return, so in case of error
// it have to be cleared before returing.
pkey = nil
}
if bd != nil {
C.go_openssl_BN_free(bd)
}
}()
bx = bigToBN(X)
by = bigToBN(Y)
if bx == nil || by == nil {
bd = bigToBN(D)
if bx == nil || by == nil || (D != nil && bd == nil) {
return nil, newOpenSSLError("BN_lebin2bn failed")
}
if key = C.go_openssl_EC_KEY_new_by_curve_name(nid); key == nil {
key := C.go_openssl_EC_KEY_new_by_curve_name(nid)
if key == nil {
return nil, newOpenSSLError("EC_KEY_new_by_curve_name failed")
}
var pkey C.GO_EVP_PKEY_PTR
defer func() {
if pkey == nil {
defer C.go_openssl_EC_KEY_free(key)
}
}()
if C.go_openssl_EC_KEY_set_public_key_affine_coordinates(key, bx, by) != 1 {
return nil, newOpenSSLError("EC_KEY_set_public_key_affine_coordinates failed")
}
if D != nil {
bd := bigToBN(D)
if bd == nil {
return nil, newOpenSSLError("BN_lebin2bn failed")
}
defer C.go_openssl_BN_free(bd)
if C.go_openssl_EC_KEY_set_private_key(key, bd) != 1 {
return nil, newOpenSSLError("EC_KEY_set_private_key failed")
}
}
if pkey = C.go_openssl_EVP_PKEY_new(); pkey == nil {
return nil, newOpenSSLError("EVP_PKEY_new failed")
if D != nil && C.go_openssl_EC_KEY_set_private_key(key, bd) != 1 {
return nil, newOpenSSLError("EC_KEY_set_private_key failed")
}
if C.go_openssl_EVP_PKEY_assign(pkey, C.GO_EVP_PKEY_EC, (unsafe.Pointer)(key)) != 1 {
return nil, newOpenSSLError("EVP_PKEY_assign failed")
pkey, err = newEVPPKEY(key)
if err != nil {
return nil, err
}
return pkey, nil
}
Expand Down
Loading

0 comments on commit 5eef080

Please sign in to comment.