diff --git a/src/ec.c b/src/ec.c index 075b894f..fae38409 100644 --- a/src/ec.c +++ b/src/ec.c @@ -640,10 +640,13 @@ static int openssl_ec_key_parse(lua_State*L) BIGNUM* x = BN_new(); BIGNUM* y = BN_new(); - priv = BN_dup(priv); - AUXILIAR_SETOBJECT(L, priv, "openssl.bn", -1, "d"); + if (priv != NULL) + { + priv = BN_dup(priv); + AUXILIAR_SETOBJECT(L, priv, "openssl.bn", -1, "d"); + } - if (EC_POINT_get_affine_coordinates(group, point, x, y, NULL) == 1) + if (point && EC_POINT_get_affine_coordinates(group, point, x, y, NULL) == 1) { AUXILIAR_SETOBJECT(L, x, "openssl.bn", -1, "x"); AUXILIAR_SETOBJECT(L, y, "openssl.bn", -1, "y"); diff --git a/src/pkey.c b/src/pkey.c index 7bcf3220..f8ff8f54 100644 --- a/src/pkey.c +++ b/src/pkey.c @@ -467,6 +467,9 @@ static LUA_FUNCTION(openssl_pkey_new) EC_GROUP_free(group); if (EC_KEY_generate_key(ec)) { +#if OPENSSL_VERSION_NUMBER > 0x30000000L + EC_KEY_generate_key_part(ec); +#endif pkey = EVP_PKEY_new(); EVP_PKEY_assign_EC_KEY(pkey, ec); } @@ -1241,22 +1244,35 @@ static LUA_FUNCTION(openssl_pkey_get_public) { EVP_PKEY *pkey = CHECK_OBJECT(1, EVP_PKEY, "openssl.evp_pkey"); int ret = 0; - size_t len = i2d_PUBKEY(pkey, NULL); + size_t len; + +#if OPENSSL_VERSION_NUMBER > 0x30000000L && !defined(LIBRESSL_VERSION_NUMBER) + if(EVP_PKEY_id(pkey) == EVP_PKEY_SM2) + { + /* NOTES: bugs in openssl3 for SM2, ugly hack */ + EC_KEY *ec = EVP_PKEY_get0_EC_KEY(pkey); + ec = EC_KEY_dup(ec); + EC_KEY_set_private_key(ec, NULL); + pkey = EVP_PKEY_new(); + EVP_PKEY_assign_EC_KEY(pkey, ec); + PUSH_OBJECT(pkey, "openssl.evp_pkey"); + return 1; + } +#endif + + len = i2d_PUBKEY(pkey, NULL); + if (len > 0) { unsigned char *buf = OPENSSL_malloc(len); - EVP_PKEY *pub = EVP_PKEY_new(); -#if OPENSSL_VERSION_NUMBER > 0x30000000L && !defined(LIBRESSL_VERSION_NUMBER) - if (pub != NULL && EVP_PKEY_copy_parameters(pub, pkey) <= 0) - { - EVP_PKEY_free(pub); - pub = NULL; - } -#endif if (buf != NULL) { unsigned char *p = buf; + EVP_PKEY *pub = EVP_PKEY_new(); +#if OPENSSL_VERSION_NUMBER > 0x30000000L && !defined(LIBRESSL_VERSION_NUMBER) + EVP_PKEY_copy_parameters(pub, pkey); +#endif len = i2d_PUBKEY(pkey, &p); p = buf; pub = d2i_PUBKEY(&pub, (const unsigned char **)&p, len); @@ -1599,11 +1615,7 @@ static LUA_FUNCTION(openssl_sign) size_t idlen = 0; const char* userId = luaL_optlstring (L, 4, SM2_DEFAULT_USERID, &idlen); -#if OPENSSL_VERSION_NUMBER > 0x30000000 - pctx = EVP_PKEY_CTX_new_from_name(NULL, "sm2", NULL); -#else pctx = EVP_PKEY_CTX_new(pkey, NULL); -#endif EVP_PKEY_CTX_set1_id(pctx, userId, idlen); EVP_MD_CTX_set_pkey_ctx(ctx, pctx); } @@ -1686,11 +1698,7 @@ static LUA_FUNCTION(openssl_verify) const char* userId = luaL_optlstring (L, 5, SM2_DEFAULT_USERID, &idlen); -#if OPENSSL_VERSION_NUMBER > 0x30000000 - pctx = EVP_PKEY_CTX_new_from_name(NULL, "sm2", NULL); -#else pctx = EVP_PKEY_CTX_new(pkey, NULL); -#endif EVP_PKEY_CTX_set1_id(pctx, userId, idlen); EVP_MD_CTX_set_pkey_ctx(ctx, pctx); } diff --git a/test/sm2.lua b/test/sm2.lua index 0e34046d..0ed3f6a7 100644 --- a/test/sm2.lua +++ b/test/sm2.lua @@ -5,114 +5,113 @@ local unpack = unpack or table.unpack local helper = require'helper' local _,_,opensslv = openssl.version(true) -if opensslv >= 0x10101007 and (not helper.libressl) then - if helper.openssl3 then --FIXME: get public key, sign, verify - print('Support SM2, but bugs, skip') - return - else - print('Support SM2') - end - - testSM2 = {} +if opensslv <= 0x10101007 or helper.libressl then + print('SKIP SM2') + return +end - function testSM2:testSM2() - local nec = {'ec','SM2'} - local ec = pkey.new(unpack(nec)) - local t = ec:parse() - if helper.openssl3 then - lu.assertEquals(t.type, 'SM2') - t = t.sm2:parse(true) --make basic table - else - lu.assertEquals(t.type, 'EC') - t = t.ec:parse(true) --make basic table - end - lu.assertEquals(type(t.curve_name), 'number') - lu.assertStrContains(t.x.version, 'bn library') - lu.assertStrContains(t.y.version, 'bn library') - lu.assertStrContains(t.d.version, 'bn library') +TestSM2 = {} - local k1 = pkey.get_public(ec) - assert(not k1:is_private()) - t = k1:parse() - assert(t.bits==256) - assert(t.type=='EC') - assert(t.size==72) - local r = t.ec - t = r:parse(true) --make basic table - lu.assertEquals(type(t.curve_name), 'number') - lu.assertStrContains(t.x.version, 'bn library') - lu.assertStrContains(t.y.version, 'bn library') - lu.assertEquals(t.d, nil) - t = r:parse() - lu.assertStrContains(tostring(t.pub_key), 'openssl.ec_point') - lu.assertStrContains(tostring(t.group), 'openssl.ec_group') - local x, y = t.group:affine_coordinates(t.pub_key) - lu.assertStrContains(x.version, 'bn library') - lu.assertStrContains(y.version, 'bn library') - local ec2p = { - alg = 'ec', - ec_name = t.group:parse().curve_name, - x = x, - y = y, - } - local ec2 = pkey.new(ec2p) - assert(not ec2:is_private()) +function TestSM2:TestSM2() + local nec = {'ec','SM2'} - ec2p.d = ec:parse().ec:parse().priv_key - local ec2priv = pkey.new(ec2p) - assert(ec2priv:is_private()) + local ec = pkey.new(unpack(nec)) + local t = ec:parse() + if helper.openssl3 then + lu.assertEquals(t.type, 'SM2') + t = t.sm2:parse(true) --make basic table + else + lu.assertEquals(t.type, 'EC') + t = t.ec:parse(true) --make basic table + end + lu.assertEquals(type(t.curve_name), 'number') + lu.assertStrContains(t.x.version, 'bn library') + lu.assertStrContains(t.y.version, 'bn library') + lu.assertStrContains(t.d.version, 'bn library') - nec = {'ec','SM2'} - local key1 = pkey.new(unpack(nec)) - local key2 = pkey.new(unpack(nec)) - local ec1 = key1:parse().ec - ec2 = key2:parse().ec - local secret1 = ec1:compute_key(ec2) - local secret2 = ec2:compute_key(ec1) - assert(secret1==secret2) + local k1 = pkey.get_public(ec) + assert(not k1:is_private()) + t = k1:parse() + assert(not k1:missing_paramaters()) + assert(t.bits==256) + assert(t.type==(helper.openssl3 and 'SM2' or 'EC'), t.type) + assert(t.size==72) + local r = helper.openssl3 and t.sm2 or t.ec + t = r:parse(true) --make basic table + lu.assertEquals(type(t.curve_name), 'number') + lu.assertStrContains(t.x.version, 'bn library') + lu.assertStrContains(t.y.version, 'bn library') + lu.assertEquals(t.d, nil) + t = r:parse() + lu.assertStrContains(tostring(t.pub_key), 'openssl.ec_point') + lu.assertStrContains(tostring(t.group), 'openssl.ec_group') + local x, y = t.group:affine_coordinates(t.pub_key) + lu.assertStrContains(x.version, 'bn library') + lu.assertStrContains(y.version, 'bn library') + local ec2p = { + alg = 'ec', + ec_name = t.group:parse().curve_name, + x = x, + y = y, + } + local ec2 = pkey.new(ec2p) + assert(not ec2:is_private()) - local pub1 = pkey.get_public(key1) - local pub2 = pkey.get_public(key2) - pub1 = pub1:parse().ec - pub2 = pub2:parse().ec + t = ec:parse() + t = helper.openssl3 and t.sm2 or t.ec + ec2p.d = t:parse().priv_key + local ec2priv = pkey.new(ec2p) + assert(ec2priv:is_private()) - secret1 = ec1:compute_key(pub2) - secret2 = ec2:compute_key(pub1) - assert(secret1==secret2) - end + nec = {'ec','SM2'} + local key1 = pkey.new(unpack(nec)) + local key2 = pkey.new(unpack(nec)) + local ec1 + if helper.openssl3 then + ec1 = key1:parse().sm2 + ec2 = key2:parse().sm2 + else + ec1 = key1:parse().ec + ec2 = key2:parse().ec + end + local secret1 = ec1:compute_key(ec2) + local secret2 = ec2:compute_key(ec1) + assert(secret1==secret2) - function testSM2:testEC_SignVerify() - local nec = {'ec','SM2'} - local pri = pkey.new(unpack(nec)) - local pub = pri:get_public() - local msg = openssl.random(32) - local sig + local pub1 = pkey.get_public(key1) + local pub2 = pkey.get_public(key2) + if helper.openssl3 then + pub1 = pub1:parse().sm2 + pub2 = pub2:parse().sm2 + else + pub1 = pub1:parse().ec + pub2 = pub2:parse().ec + end - -- FIXME: openssl - --if openssl.digest.get('sm3') then - -- sig = assert(pri:sign(msg, 'sm3')) - -- assert(pub:verify(msg, sig, 'sm3')) - --end + secret1 = ec1:compute_key(pub2) + secret2 = ec2:compute_key(pub1) + assert(secret1==secret2) +end - sig = assert(pri:sign(msg, 'sha256')) - assert(pub:verify(msg, sig, 'sha256')) - end +function TestSM2:TestSM2_SignVerify() + local nec = {'ec','SM2'} + local pri = pkey.new(unpack(nec)) + local pub = pri:get_public() + local msg = openssl.random(33) - function testSM2:testSM2_SignVerify() - local nec = {'ec','SM2'} - local pri = pkey.new(unpack(nec)) - local pub = pri:get_public() - local msg = openssl.random(33) + local sig + if not helper.openssl3 then + sig = assert(pri:sign(msg, 'sha256')) + assert(pub:verify(msg, sig, 'sha256')) + end - if pri.as_sm2 then - assert(pri:as_sm2()) - assert(pub:as_sm2()) - end + if pri.as_sm2 then + -- OpenSSL v1.1.1 + assert(pri:as_sm2()) + assert(pub:as_sm2()) + end - local sig = assert(pri:sign(msg, 'sm3')) - assert(pub:verify(msg, sig, 'sm3')) - end -else - print('Skip SM2') + sig = assert(pri:sign(msg, 'sm3')) + assert(pub:verify(msg, sig, 'sm3')) end