From e929f8fe7c14d39c361f88eaffa38cfba48b38e9 Mon Sep 17 00:00:00 2001 From: David Mihalcik Date: Mon, 6 Jan 2025 13:29:05 -0500 Subject: [PATCH] clenaups --- service/kas/access/provider.go | 4 +-- service/kas/access/publicKey_test.go | 9 ++++--- service/kas/recrypt/standard.go | 39 ++++++++++++++++++++++------ 3 files changed, 39 insertions(+), 13 deletions(-) diff --git a/service/kas/access/provider.go b/service/kas/access/provider.go index ed310f448..8c0515ba1 100644 --- a/service/kas/access/provider.go +++ b/service/kas/access/provider.go @@ -42,7 +42,7 @@ func (p *Provider) LoadStandardCryptoProvider() (*recrypt.Standard, error) { for _, key := range p.Keyring { privatePemData, err := os.ReadFile(key.Private) if err != nil { - return nil, fmt.Errorf("failed to rsa private key file: %w", err) + return nil, fmt.Errorf("failure to read rsa private key file [%s]: %w", key.Private, err) } var secret crypto.PrivateKey @@ -65,7 +65,7 @@ func (p *Provider) LoadStandardCryptoProvider() (*recrypt.Standard, error) { if key.Certificate != "" { publicPemData, err = os.ReadFile(key.Certificate) if err != nil { - return nil, fmt.Errorf("failed to rsa public key file: %w", err) + return nil, fmt.Errorf("failure to read rsa public key file [%s]: %w", key.Certificate, err) } } opts = append(opts, recrypt.WithKey(key.KID, key.Algorithm, secret, publicPemData, key.Active, key.Legacy)) diff --git a/service/kas/access/publicKey_test.go b/service/kas/access/publicKey_test.go index da33ca652..e6d626855 100644 --- a/service/kas/access/publicKey_test.go +++ b/service/kas/access/publicKey_test.go @@ -124,6 +124,7 @@ func TestStandardPublicKeyHandlerV2(t *testing.T) { KID: "rsa", Private: "./testdata/access-provider-000-private.pem", Certificate: "./testdata/access-provider-000-certificate.pem", + Active: true, }, }, }, @@ -134,7 +135,7 @@ func TestStandardPublicKeyHandlerV2(t *testing.T) { result, err := kas.PublicKey(context.Background(), &connect.Request[kaspb.PublicKeyRequest]{Msg: &kaspb.PublicKeyRequest{}}) require.NoError(t, err) require.NotNil(t, result) - assert.Contains(t, result.Msg.GetPublicKey(), "BEGIN CERTIFICATE") + assert.Contains(t, result.Msg.GetPublicKey(), "BEGIN PUBLIC KEY") } func TestStandardPublicKeyHandlerV2Failure(t *testing.T) { @@ -163,6 +164,7 @@ func TestStandardPublicKeyHandlerV2NotFound(t *testing.T) { KID: "rsa", Private: "./testdata/access-provider-000-private.pem", Certificate: "./testdata/access-provider-000-certificate.pem", + Active: true, }, }, }, @@ -194,6 +196,7 @@ func TestStandardPublicKeyHandlerV2WithJwk(t *testing.T) { KID: "rsa", Private: "./testdata/access-provider-000-private.pem", Certificate: "./testdata/access-provider-000-certificate.pem", + Active: true, }, }, }, @@ -244,7 +247,6 @@ func TestStandardCertificateHandlerWithEc256(t *testing.T) { } func TestStandardPublicKeyHandlerWithEc256(t *testing.T) { - t.Skip("EC Not yet implemented") kasURI := urlHost(t) kas := Provider{ URI: *kasURI, @@ -256,6 +258,7 @@ func TestStandardPublicKeyHandlerWithEc256(t *testing.T) { KID: "rsa", Private: "./testdata/access-provider-ec-private.pem", Certificate: "./testdata/access-provider-ec-certificate.pem", + Active: true, }, }, }, @@ -274,7 +277,6 @@ func TestStandardPublicKeyHandlerWithEc256(t *testing.T) { } func TestStandardPublicKeyHandlerV2WithEc256(t *testing.T) { - t.Skip("EC Not yet implemented") kasURI := urlHost(t) kas := Provider{ URI: *kasURI, @@ -286,6 +288,7 @@ func TestStandardPublicKeyHandlerV2WithEc256(t *testing.T) { KID: "rsa", Private: "./testdata/access-provider-ec-private.pem", Certificate: "./testdata/access-provider-ec-certificate.pem", + Active: true, }, }, }, diff --git a/service/kas/recrypt/standard.go b/service/kas/recrypt/standard.go index e0e538cd2..1ce99202c 100644 --- a/service/kas/recrypt/standard.go +++ b/service/kas/recrypt/standard.go @@ -14,7 +14,6 @@ import ( "encoding/hex" "encoding/json" "encoding/pem" - "errors" "fmt" "log/slog" "math/big" @@ -29,6 +28,7 @@ type keyHolder struct { Algorithm KeyIdentifier crypto.PrivateKey + certPEM []byte publicPEM []byte } @@ -51,13 +51,13 @@ func NewStandard() *Standard { type StandardOption func(*Standard) error // WithKey adds the given key by type and id. -func WithKey(id KeyIdentifier, alg Algorithm, privateKey crypto.PrivateKey, publicPEM []byte, isCurrent, checkForLegacy bool) StandardOption { +func WithKey(id KeyIdentifier, alg Algorithm, privateKey crypto.PrivateKey, certPEM []byte, isCurrent, checkForLegacy bool) StandardOption { return func(s *Standard) error { s.keys[id] = keyHolder{ Algorithm: alg, KeyIdentifier: id, PrivateKey: privateKey, - publicPEM: publicPEM, + certPEM: certPEM, } if isCurrent { s.currentKIDsByAlg[alg] = append(s.currentKIDsByAlg[alg], id) @@ -149,11 +149,20 @@ func (s *Standard) PublicKey(a Algorithm, k []KeyIdentifier, f KeyFormat) (strin if !ok { return "", fmt.Errorf("key not found [%s]", kid) } - jwk, err := jwk.FromRaw(holder.publicPEM) + var j jwk.Key + var err error + switch secret := holder.PrivateKey.(type) { + case *ecdsa.PrivateKey: + j, err = jwk.FromRaw(secret.Public()) + case *rsa.PrivateKey: + j, err = jwk.FromRaw(secret.Public()) + default: + return "", fmt.Errorf("invalid algorithm [%s] or format [%s]", a, f) + } if err != nil { return "", fmt.Errorf("jwk.FromRaw failed for key [%s]: %w", kid, err) } - if err := jwks.AddKey(jwk); err != nil { + if err := jwks.AddKey(j); err != nil { return "", fmt.Errorf("jwk.AddKey failed for key [%s]: %w", kid, err) } } @@ -174,6 +183,20 @@ func (s *Standard) PublicKey(a Algorithm, k []KeyIdentifier, f KeyFormat) (strin return string(holder.publicPEM), nil } switch secret := holder.PrivateKey.(type) { + case *ecdh.PrivateKey: + publicKeyBytes, err := x509.MarshalPKIXPublicKey(secret.PublicKey()) + if err != nil { + return "", fmt.Errorf("x509.MarshalPKIXPublicKey failed: %w", err) + } + + holder.publicPEM = pem.EncodeToMemory( + &pem.Block{ + Type: "PUBLIC KEY", + Bytes: publicKeyBytes, + }, + ) + return string(holder.publicPEM), nil + case *ecdsa.PrivateKey: publicPEM, err := ocrypto.ECPublicKeyInPemFormat(secret.PublicKey) if err != nil { @@ -182,7 +205,7 @@ func (s *Standard) PublicKey(a Algorithm, k []KeyIdentifier, f KeyFormat) (strin holder.publicPEM = []byte(publicPEM) return publicPEM, nil case *rsa.PrivateKey: - publicKeyBytes, err := x509.MarshalPKIXPublicKey(secret.PublicKey) + publicKeyBytes, err := x509.MarshalPKIXPublicKey(&secret.PublicKey) if err != nil { return "", fmt.Errorf("x509.MarshalPKIXPublicKey failed: %w", err) } @@ -196,9 +219,9 @@ func (s *Standard) PublicKey(a Algorithm, k []KeyIdentifier, f KeyFormat) (strin holder.publicPEM = publicPEM return string(publicPEM), nil } - return "", errors.New("invalid algorithm or format") + return "", fmt.Errorf("invalid algorithm [%T] or format [%s]", holder.PrivateKey, f) } - return "", errors.New("invalid algorithm or format") + return "", fmt.Errorf("invalid format [%s]", f) } func (s *Standard) Unwrap(k KeyIdentifier, ciphertext []byte) (UnwrappedKey, error) {