Skip to content

Commit

Permalink
clenaups
Browse files Browse the repository at this point in the history
  • Loading branch information
dmihalcik-virtru committed Jan 6, 2025
1 parent ea74aa5 commit e929f8f
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 13 deletions.
4 changes: 2 additions & 2 deletions service/kas/access/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func (p *Provider) LoadStandardCryptoProvider() (*recrypt.Standard, error) {
for _, key := range p.Keyring {
privatePemData, err := os.ReadFile(key.Private)
if err != nil {
return nil, fmt.Errorf("failed to rsa private key file: %w", err)
return nil, fmt.Errorf("failure to read rsa private key file [%s]: %w", key.Private, err)
}

var secret crypto.PrivateKey
Expand All @@ -65,7 +65,7 @@ func (p *Provider) LoadStandardCryptoProvider() (*recrypt.Standard, error) {
if key.Certificate != "" {
publicPemData, err = os.ReadFile(key.Certificate)
if err != nil {
return nil, fmt.Errorf("failed to rsa public key file: %w", err)
return nil, fmt.Errorf("failure to read rsa public key file [%s]: %w", key.Certificate, err)
}
}
opts = append(opts, recrypt.WithKey(key.KID, key.Algorithm, secret, publicPemData, key.Active, key.Legacy))
Expand Down
9 changes: 6 additions & 3 deletions service/kas/access/publicKey_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ func TestStandardPublicKeyHandlerV2(t *testing.T) {
KID: "rsa",
Private: "./testdata/access-provider-000-private.pem",
Certificate: "./testdata/access-provider-000-certificate.pem",
Active: true,
},
},
},
Expand All @@ -134,7 +135,7 @@ func TestStandardPublicKeyHandlerV2(t *testing.T) {
result, err := kas.PublicKey(context.Background(), &connect.Request[kaspb.PublicKeyRequest]{Msg: &kaspb.PublicKeyRequest{}})
require.NoError(t, err)
require.NotNil(t, result)
assert.Contains(t, result.Msg.GetPublicKey(), "BEGIN CERTIFICATE")
assert.Contains(t, result.Msg.GetPublicKey(), "BEGIN PUBLIC KEY")
}

func TestStandardPublicKeyHandlerV2Failure(t *testing.T) {
Expand Down Expand Up @@ -163,6 +164,7 @@ func TestStandardPublicKeyHandlerV2NotFound(t *testing.T) {
KID: "rsa",
Private: "./testdata/access-provider-000-private.pem",
Certificate: "./testdata/access-provider-000-certificate.pem",
Active: true,
},
},
},
Expand Down Expand Up @@ -194,6 +196,7 @@ func TestStandardPublicKeyHandlerV2WithJwk(t *testing.T) {
KID: "rsa",
Private: "./testdata/access-provider-000-private.pem",
Certificate: "./testdata/access-provider-000-certificate.pem",
Active: true,
},
},
},
Expand Down Expand Up @@ -244,7 +247,6 @@ func TestStandardCertificateHandlerWithEc256(t *testing.T) {
}

func TestStandardPublicKeyHandlerWithEc256(t *testing.T) {
t.Skip("EC Not yet implemented")
kasURI := urlHost(t)
kas := Provider{
URI: *kasURI,
Expand All @@ -256,6 +258,7 @@ func TestStandardPublicKeyHandlerWithEc256(t *testing.T) {
KID: "rsa",
Private: "./testdata/access-provider-ec-private.pem",
Certificate: "./testdata/access-provider-ec-certificate.pem",
Active: true,
},
},
},
Expand All @@ -274,7 +277,6 @@ func TestStandardPublicKeyHandlerWithEc256(t *testing.T) {
}

func TestStandardPublicKeyHandlerV2WithEc256(t *testing.T) {
t.Skip("EC Not yet implemented")
kasURI := urlHost(t)
kas := Provider{
URI: *kasURI,
Expand All @@ -286,6 +288,7 @@ func TestStandardPublicKeyHandlerV2WithEc256(t *testing.T) {
KID: "rsa",
Private: "./testdata/access-provider-ec-private.pem",
Certificate: "./testdata/access-provider-ec-certificate.pem",
Active: true,
},
},
},
Expand Down
39 changes: 31 additions & 8 deletions service/kas/recrypt/standard.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
"encoding/hex"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"log/slog"
"math/big"
Expand All @@ -29,6 +28,7 @@ type keyHolder struct {
Algorithm
KeyIdentifier
crypto.PrivateKey
certPEM []byte
publicPEM []byte
}

Expand All @@ -51,13 +51,13 @@ func NewStandard() *Standard {
type StandardOption func(*Standard) error

// WithKey adds the given key by type and id.
func WithKey(id KeyIdentifier, alg Algorithm, privateKey crypto.PrivateKey, publicPEM []byte, isCurrent, checkForLegacy bool) StandardOption {
func WithKey(id KeyIdentifier, alg Algorithm, privateKey crypto.PrivateKey, certPEM []byte, isCurrent, checkForLegacy bool) StandardOption {
return func(s *Standard) error {
s.keys[id] = keyHolder{
Algorithm: alg,
KeyIdentifier: id,
PrivateKey: privateKey,
publicPEM: publicPEM,
certPEM: certPEM,
}
if isCurrent {
s.currentKIDsByAlg[alg] = append(s.currentKIDsByAlg[alg], id)
Expand Down Expand Up @@ -149,11 +149,20 @@ func (s *Standard) PublicKey(a Algorithm, k []KeyIdentifier, f KeyFormat) (strin
if !ok {
return "", fmt.Errorf("key not found [%s]", kid)
}
jwk, err := jwk.FromRaw(holder.publicPEM)
var j jwk.Key
var err error
switch secret := holder.PrivateKey.(type) {
case *ecdsa.PrivateKey:
j, err = jwk.FromRaw(secret.Public())
case *rsa.PrivateKey:
j, err = jwk.FromRaw(secret.Public())
default:
return "", fmt.Errorf("invalid algorithm [%s] or format [%s]", a, f)
}
if err != nil {
return "", fmt.Errorf("jwk.FromRaw failed for key [%s]: %w", kid, err)
}
if err := jwks.AddKey(jwk); err != nil {
if err := jwks.AddKey(j); err != nil {
return "", fmt.Errorf("jwk.AddKey failed for key [%s]: %w", kid, err)
}
}
Expand All @@ -174,6 +183,20 @@ func (s *Standard) PublicKey(a Algorithm, k []KeyIdentifier, f KeyFormat) (strin
return string(holder.publicPEM), nil
}
switch secret := holder.PrivateKey.(type) {
case *ecdh.PrivateKey:
publicKeyBytes, err := x509.MarshalPKIXPublicKey(secret.PublicKey())
if err != nil {
return "", fmt.Errorf("x509.MarshalPKIXPublicKey failed: %w", err)
}

holder.publicPEM = pem.EncodeToMemory(
&pem.Block{
Type: "PUBLIC KEY",
Bytes: publicKeyBytes,
},
)
return string(holder.publicPEM), nil

case *ecdsa.PrivateKey:
publicPEM, err := ocrypto.ECPublicKeyInPemFormat(secret.PublicKey)
if err != nil {
Expand All @@ -182,7 +205,7 @@ func (s *Standard) PublicKey(a Algorithm, k []KeyIdentifier, f KeyFormat) (strin
holder.publicPEM = []byte(publicPEM)
return publicPEM, nil
case *rsa.PrivateKey:
publicKeyBytes, err := x509.MarshalPKIXPublicKey(secret.PublicKey)
publicKeyBytes, err := x509.MarshalPKIXPublicKey(&secret.PublicKey)
if err != nil {
return "", fmt.Errorf("x509.MarshalPKIXPublicKey failed: %w", err)
}
Expand All @@ -196,9 +219,9 @@ func (s *Standard) PublicKey(a Algorithm, k []KeyIdentifier, f KeyFormat) (strin
holder.publicPEM = publicPEM
return string(publicPEM), nil
}
return "", errors.New("invalid algorithm or format")
return "", fmt.Errorf("invalid algorithm [%T] or format [%s]", holder.PrivateKey, f)
}
return "", errors.New("invalid algorithm or format")
return "", fmt.Errorf("invalid format [%s]", f)
}

func (s *Standard) Unwrap(k KeyIdentifier, ciphertext []byte) (UnwrappedKey, error) {
Expand Down

0 comments on commit e929f8f

Please sign in to comment.