diff --git a/.github/workflows/checks.yaml b/.github/workflows/checks.yaml index 39ff80207..4f3281259 100644 --- a/.github/workflows/checks.yaml +++ b/.github/workflows/checks.yaml @@ -189,7 +189,12 @@ jobs: - run: grpcurl -plaintext localhost:8080 list - run: grpcurl -plaintext localhost:8080 grpc.health.v1.Health.Check - run: grpcurl -plaintext localhost:8080 kas.AccessService/PublicKey - - run: curl --show-error --fail --insecure localhost:8080/kas/v2/kas_public_key + - run: curl --show-error --fail-with-body --insecure localhost:8080/kas/v2/kas_public_key + - run: curl --show-error --fail-with-body --insecure localhost:8080/kas/v2/kas_public_key?algorithm=ec:secp256r1 + - run: |- + curl_status=$(curl -o /dev/null -s -w "%{http_code}" localhost:8080/kas/v2/kas_public_key?algorithm=invalid) + [ $curl_status = 404 ] + - run: grpcurl -d '{"algorithm":"invalid"}' -plaintext localhost:8080 kas.AccessService/PublicKey 2>&1 | grep NotFound - run: go run ./examples encrypt "Hello Virtru" - run: go run ./examples decrypt sensitive.txt.tdf - run: go run ./examples decrypt sensitive.txt.tdf | grep "Hello Virtru" diff --git a/service/kas/access/publicKey.go b/service/kas/access/publicKey.go index 2672381b9..e03851092 100644 --- a/service/kas/access/publicKey.go +++ b/service/kas/access/publicKey.go @@ -10,6 +10,7 @@ import ( "log/slog" kaspb "github.com/opentdf/platform/protocol/go/kas" + "github.com/opentdf/platform/service/internal/security" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" wrapperspb "google.golang.org/protobuf/types/known/wrapperspb" @@ -19,78 +20,107 @@ const ( ErrCertificateEncode = Error("certificate encode error") ErrPublicKeyMarshal = Error("public key marshal error") algorithmEc256 = "ec:secp256r1" + algorithmRSA2048 = "rsa:2048" ) -func (p *Provider) LegacyPublicKey(ctx context.Context, in *kaspb.LegacyPublicKeyRequest) (*wrapperspb.StringValue, error) { +func (p Provider) lookupKid(ctx context.Context, algorithm string) (string, error) { + key := "unknown" + defaultKid := "unknown" + if algorithm == algorithmEc256 { + defaultKid = "123" + key = "eccertid" + } + + if p.Config == nil || p.Config.ExtraProps == nil { + slog.WarnContext(ctx, "using default kid", "kid", defaultKid, "algorithm", algorithm, "certid", key) + return defaultKid, nil + } + + certid, ok := p.Config.ExtraProps[key] + if !ok { + slog.WarnContext(ctx, "using default kid", "kid", defaultKid, "algorithm", algorithm, "certid", key) + return defaultKid, nil + } + + kid, ok := certid.(string) + if !ok { + slog.ErrorContext(ctx, "invalid key configuration", "kid", defaultKid, "algorithm", algorithm, "certid", key) + return "", ErrConfig + } + return kid, nil +} + +func (p Provider) LegacyPublicKey(ctx context.Context, in *kaspb.LegacyPublicKeyRequest) (*wrapperspb.StringValue, error) { algorithm := in.GetAlgorithm() var pem string var err error if p.CryptoProvider == nil { return nil, errors.Join(ErrConfig, status.Error(codes.Internal, "configuration error")) } - if algorithm == algorithmEc256 { - ecCertIDInf := p.Config.ExtraProps["eccertid"] - ecCertID, ok := ecCertIDInf.(string) - if !ok { - return nil, errors.New("services.kas.eccertid is not a string") - } - pem, err = p.CryptoProvider.ECCertificate(ecCertID) + kid, err := p.lookupKid(ctx, algorithm) + if err != nil { + return nil, err + } + + switch algorithm { + case algorithmEc256: + pem, err = p.CryptoProvider.ECCertificate(kid) if err != nil { slog.ErrorContext(ctx, "CryptoProvider.ECPublicKey failed", "err", err) return nil, errors.Join(ErrConfig, status.Error(codes.Internal, "configuration error")) } - } else { - pem, err = p.CryptoProvider.RSAPublicKey("unknown") + case algorithmRSA2048: + fallthrough + case "": + pem, err = p.CryptoProvider.RSAPublicKey(kid) if err != nil { slog.ErrorContext(ctx, "CryptoProvider.RSAPublicKey failed", "err", err) return nil, errors.Join(ErrConfig, status.Error(codes.Internal, "configuration error")) } - } - if err != nil { - slog.ErrorContext(ctx, "unable to generate PEM", "err", err) - return nil, errors.Join(ErrConfig, status.Error(codes.Internal, "configuration error")) + default: + return nil, errors.Join(ErrConfig, status.Error(codes.NotFound, "invalid algorithm")) } return &wrapperspb.StringValue{Value: pem}, nil } -func (p *Provider) PublicKey(ctx context.Context, in *kaspb.PublicKeyRequest) (*kaspb.PublicKeyResponse, error) { +func (p Provider) PublicKey(ctx context.Context, in *kaspb.PublicKeyRequest) (*kaspb.PublicKeyResponse, error) { algorithm := in.GetAlgorithm() - if algorithm == algorithmEc256 { - ecPublicKeyPem, err := p.CryptoProvider.ECPublicKey("123") - if err != nil { - slog.ErrorContext(ctx, "CryptoProvider.ECPublicKey failed", "err", err) - return nil, errors.Join(ErrConfig, status.Error(codes.Internal, "configuration error")) - } - - return &kaspb.PublicKeyResponse{PublicKey: ecPublicKeyPem}, nil + fmt := in.GetFmt() + kid, err := p.lookupKid(ctx, algorithm) + if err != nil { + return nil, err } - if in.GetFmt() == "jwk" { - rsaPublicKeyPem, err := p.CryptoProvider.RSAPublicKeyAsJSON("unknown") - if err != nil { - slog.ErrorContext(ctx, "CryptoProvider.RSAPublicKey failed", "err", err) + r := func(k string, err error) (*kaspb.PublicKeyResponse, error) { + if errors.Is(err, security.ErrCertNotFound) { + slog.ErrorContext(ctx, "no key found for", "err", err, "kid", kid, "algorithm", algorithm, "fmt", fmt) + return nil, errors.Join(err, status.Error(codes.NotFound, "no such key")) + } else if err != nil { + slog.ErrorContext(ctx, "configuration error for key lookup", "err", err, "kid", kid, "algorithm", algorithm, "fmt", fmt) return nil, errors.Join(ErrConfig, status.Error(codes.Internal, "configuration error")) } - - return &kaspb.PublicKeyResponse{PublicKey: rsaPublicKeyPem}, nil + return &kaspb.PublicKeyResponse{PublicKey: k}, nil } - if in.GetFmt() == "pkcs8" { - rsaPublicKeyPem, err := p.CryptoProvider.RSAPublicKey("unknown") - if err != nil { - slog.ErrorContext(ctx, "CryptoProvider.RSAPublicKey failed", "err", err) - return nil, errors.Join(ErrConfig, status.Error(codes.Internal, "configuration error")) + switch algorithm { + case algorithmEc256: + ecPublicKeyPem, err := p.CryptoProvider.ECPublicKey(kid) + return r(ecPublicKeyPem, err) + case algorithmRSA2048: + fallthrough + case "": + switch fmt { + case "jwk": + rsaPublicKeyPem, err := p.CryptoProvider.RSAPublicKeyAsJSON(kid) + return r(rsaPublicKeyPem, err) + case "pkcs8": + fallthrough + case "": + rsaPublicKeyPem, err := p.CryptoProvider.RSAPublicKey(kid) + return r(rsaPublicKeyPem, err) } - return &kaspb.PublicKeyResponse{PublicKey: rsaPublicKeyPem}, nil } - - rsaPublicKeyPem, err := p.CryptoProvider.RSAPublicKey("unknown") - if err != nil { - slog.ErrorContext(ctx, "CryptoProvider.RSAPublicKey failed", "err", err) - return nil, errors.Join(ErrConfig, status.Error(codes.Internal, "configuration error")) - } - - return &kaspb.PublicKeyResponse{PublicKey: rsaPublicKeyPem}, nil + return nil, status.Error(codes.NotFound, "invalid algorithm or format") } func exportRsaPublicKeyAsPemStr(pubkey *rsa.PublicKey) (string, error) { diff --git a/service/kas/access/publicKey_test.go b/service/kas/access/publicKey_test.go index e8c1d3fbd..30f32163c 100644 --- a/service/kas/access/publicKey_test.go +++ b/service/kas/access/publicKey_test.go @@ -18,6 +18,8 @@ import ( "github.com/opentdf/platform/service/internal/security" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) // Skips if not in CI and failure due to library missing @@ -154,7 +156,7 @@ func TestStandardPublicKeyHandlerV2(t *testing.T) { CryptoProvider: c, } - result, err := kas.PublicKey(context.Background(), &kaspb.PublicKeyRequest{Algorithm: "rsa"}) + result, err := kas.PublicKey(context.Background(), &kaspb.PublicKeyRequest{}) require.NoError(t, err) require.NotNil(t, result) assert.Contains(t, result.GetPublicKey(), "BEGIN PUBLIC KEY") @@ -172,11 +174,41 @@ func TestStandardPublicKeyHandlerV2Failure(t *testing.T) { CryptoProvider: c, } - k, err := kas.PublicKey(context.Background(), &kaspb.PublicKeyRequest{Algorithm: "rsa"}) + k, err := kas.PublicKey(context.Background(), &kaspb.PublicKeyRequest{}) assert.Nil(t, k) require.Error(t, err) } +func TestStandardPublicKeyHandlerV2NotFound(t *testing.T) { + configStandard := security.Config{ + Type: "standard", + StandardConfig: security.StandardConfig{ + RSAKeys: map[string]security.StandardKeyInfo{ + "rsa": { + PrivateKeyPath: "./testdata/access-provider-000-private.pem", + PublicKeyPath: "./testdata/access-provider-000-certificate.pem", + }, + }, + }, + } + c := mustNewCryptoProvider(t, configStandard) + defer c.Close() + kasURI := urlHost(t) + kas := Provider{ + URI: *kasURI, + CryptoProvider: c, + } + + k, err := kas.PublicKey(context.Background(), &kaspb.PublicKeyRequest{ + Algorithm: "algorithm:unknown", + }) + assert.Nil(t, k) + require.Error(t, err) + status, ok := status.FromError(err) + assert.True(t, ok) + assert.Equal(t, codes.NotFound, status.Code()) +} + func TestStandardPublicKeyHandlerV2WithJwk(t *testing.T) { configStandard := security.Config{ Type: "standard", @@ -198,7 +230,7 @@ func TestStandardPublicKeyHandlerV2WithJwk(t *testing.T) { } result, err := kas.PublicKey(context.Background(), &kaspb.PublicKeyRequest{ - Algorithm: "rsa", + Algorithm: "rsa:2048", V: "2", Fmt: "jwk", })