From ce289a44505e5e3be995e5049f5cbbfb1839f41b Mon Sep 17 00:00:00 2001 From: Sean Trantalis Date: Wed, 6 Nov 2024 17:29:17 -0500 Subject: [PATCH] feat: update service registry in preperation for connectrpc migration (#1715) Depends On: #1709 To transition to ConnectRPC and register service handlers effectively, we need to utilize generics. The reason for this is that the generated ConnectRPC code for creating handlers enforces strong typing: `NewXServiceHandler(svc XServiceHandler, opts ...connect.HandlerOption)` For reference, see an example here: https://github.com/opentdf/platform/blob/0ef65d410a8e1bc8b82f52b6a4f0f469a2f7f4fe/protocol/go/policy/kasregistry/kasregistryconnect/key_access_server_registry.connect.go#L190C51-L190C88 By leveraging generics, we can maintain type safety while also keeping interceptors centrally managed. This approach ensures that no individual service can inadvertently modify the interceptor list, helping maintain consistent security and functionality across all services. --------- Co-authored-by: Jake Van Vorhis <83739412+jakedoublev@users.noreply.github.com> --- service/authorization/authorization.go | 132 +++++++++-------- .../claims/claims_entity_resolution.go | 11 +- service/entityresolution/entityresolution.go | 42 +++--- .../keycloak/keycloak_entity_resolution.go | 10 +- service/go.mod | 1 + service/go.sum | 2 + service/health/health.go | 34 ++--- service/kas/kas.go | 122 ++++++++-------- service/pkg/server/options.go | 8 +- service/pkg/server/services.go | 18 +-- service/pkg/server/services_test.go | 41 +++--- service/pkg/server/start.go | 4 +- service/pkg/server/start_test.go | 21 ++- .../pkg/serviceregistry/serviceregistry.go | 133 ++++++++++++------ service/policy/attributes/attributes.go | 20 +-- .../kasregistry/key_access_server_registry.go | 22 +-- service/policy/namespaces/namespaces.go | 32 ++--- service/policy/policy.go | 24 ++-- .../resourcemapping/resource_mapping.go | 22 +-- .../policy/subjectmapping/subject_mapping.go | 22 +-- service/policy/unsafe/unsafe.go | 21 +-- .../wellknown_configuration.go | 20 +-- 22 files changed, 410 insertions(+), 352 deletions(-) diff --git a/service/authorization/authorization.go b/service/authorization/authorization.go index 049e0a96a..8031fa780 100644 --- a/service/authorization/authorization.go +++ b/service/authorization/authorization.go @@ -54,92 +54,90 @@ type CustomRego struct { Query string `mapstructure:"query" json:"query" default:"data.opentdf.entitlements.attributes"` } -func NewRegistration() serviceregistry.Registration { - return serviceregistry.Registration{ - Namespace: "authorization", - ServiceDesc: &authorization.AuthorizationService_ServiceDesc, - RegisterFunc: func(srp serviceregistry.RegistrationParams) (any, serviceregistry.HandlerServer) { - var ( - err error - entitlementRego []byte - authZCfg = new(Config) - ) - - logger := srp.Logger - - // default ERS endpoint - as := &AuthorizationService{sdk: srp.SDK, logger: logger} - if err := srp.RegisterReadinessCheck("authorization", as.IsReady); err != nil { - logger.Error("failed to register authorization readiness check", slog.String("error", err.Error())) - } +func NewRegistration() *serviceregistry.Service[AuthorizationService] { + return &serviceregistry.Service[AuthorizationService]{ + ServiceOptions: serviceregistry.ServiceOptions[AuthorizationService]{ + Namespace: "authorization", + ServiceDesc: &authorization.AuthorizationService_ServiceDesc, + RegisterFunc: func(srp serviceregistry.RegistrationParams) (*AuthorizationService, serviceregistry.HandlerServer) { + var ( + err error + entitlementRego []byte + authZCfg = new(Config) + ) - if err := defaults.Set(authZCfg); err != nil { - panic(fmt.Errorf("failed to set defaults for authorization service config: %w", err)) - } + logger := srp.Logger - // Only decode config if it exists - if srp.Config != nil { - if err := mapstructure.Decode(srp.Config, &authZCfg); err != nil { - panic(fmt.Errorf("invalid auth svc cfg [%v] %w", srp.Config, err)) + // default ERS endpoint + as := &AuthorizationService{sdk: srp.SDK, logger: logger} + if err := srp.RegisterReadinessCheck("authorization", as.IsReady); err != nil { + logger.Error("failed to register authorization readiness check", slog.String("error", err.Error())) } - } - // Validate Config - validate := validator.New(validator.WithRequiredStructEnabled()) - if err := validate.Struct(authZCfg); err != nil { - var invalidValidationError *validator.InvalidValidationError - if errors.As(err, &invalidValidationError) { - logger.Error("error validating authorization service config", slog.String("error", err.Error())) - panic(fmt.Errorf("error validating authorization service config: %w", err)) + if err := defaults.Set(authZCfg); err != nil { + panic(fmt.Errorf("failed to set defaults for authorization service config: %w", err)) + } + + // Only decode config if it exists + if srp.Config != nil { + if err := mapstructure.Decode(srp.Config, &authZCfg); err != nil { + panic(fmt.Errorf("invalid auth svc cfg [%v] %w", srp.Config, err)) + } } - var validationErrors validator.ValidationErrors - if errors.As(err, &validationErrors) { - for _, err := range validationErrors { + // Validate Config + validate := validator.New(validator.WithRequiredStructEnabled()) + if err := validate.Struct(authZCfg); err != nil { + var invalidValidationError *validator.InvalidValidationError + if errors.As(err, &invalidValidationError) { logger.Error("error validating authorization service config", slog.String("error", err.Error())) panic(fmt.Errorf("error validating authorization service config: %w", err)) } + + var validationErrors validator.ValidationErrors + if errors.As(err, &validationErrors) { + for _, err := range validationErrors { + logger.Error("error validating authorization service config", slog.String("error", err.Error())) + panic(fmt.Errorf("error validating authorization service config: %w", err)) + } + } } - } - logger.Debug("authorization service config", slog.Any("config", *authZCfg)) + logger.Debug("authorization service config", slog.Any("config", *authZCfg)) - // Build Rego PreparedEvalQuery + // Build Rego PreparedEvalQuery - // Load rego from embedded file or custom path - if authZCfg.Rego.Path != "" { - entitlementRego, err = os.ReadFile(authZCfg.Rego.Path) - if err != nil { - panic(fmt.Errorf("failed to read custom entitlements.rego file: %w", err)) - } - } else { - entitlementRego, err = policies.EntitlementsRego.ReadFile("entitlements/entitlements.rego") - if err != nil { - panic(fmt.Errorf("failed to read entitlements.rego file: %w", err)) + // Load rego from embedded file or custom path + if authZCfg.Rego.Path != "" { + entitlementRego, err = os.ReadFile(authZCfg.Rego.Path) + if err != nil { + panic(fmt.Errorf("failed to read custom entitlements.rego file: %w", err)) + } + } else { + entitlementRego, err = policies.EntitlementsRego.ReadFile("entitlements/entitlements.rego") + if err != nil { + panic(fmt.Errorf("failed to read entitlements.rego file: %w", err)) + } } - } - // Register builtin - subjectmappingbuiltin.SubjectMappingBuiltin() + // Register builtin + subjectmappingbuiltin.SubjectMappingBuiltin() - as.eval, err = rego.New( - rego.Query(authZCfg.Rego.Query), - rego.Module("entitlements.rego", string(entitlementRego)), - rego.StrictBuiltinErrors(true), - ).PrepareForEval(context.Background()) - if err != nil { - panic(fmt.Errorf("failed to prepare entitlements.rego for eval: %w", err)) - } + as.eval, err = rego.New( + rego.Query(authZCfg.Rego.Query), + rego.Module("entitlements.rego", string(entitlementRego)), + rego.StrictBuiltinErrors(true), + ).PrepareForEval(context.Background()) + if err != nil { + panic(fmt.Errorf("failed to prepare entitlements.rego for eval: %w", err)) + } - as.config = *authZCfg + as.config = *authZCfg - return as, func(ctx context.Context, mux *runtime.ServeMux, server any) error { - authServer, okAuth := server.(authorization.AuthorizationServiceServer) - if !okAuth { - return fmt.Errorf("failed to assert server type to authorization.AuthorizationServiceServer") + return as, func(ctx context.Context, mux *runtime.ServeMux) error { + return authorization.RegisterAuthorizationServiceHandlerServer(ctx, mux, as) } - return authorization.RegisterAuthorizationServiceHandlerServer(ctx, mux, authServer) - } + }, }, } } diff --git a/service/entityresolution/claims/claims_entity_resolution.go b/service/entityresolution/claims/claims_entity_resolution.go index cde559cc2..07ca34fd4 100644 --- a/service/entityresolution/claims/claims_entity_resolution.go +++ b/service/entityresolution/claims/claims_entity_resolution.go @@ -22,10 +22,11 @@ type ClaimsEntityResolutionService struct { logger *logger.Logger } -func RegisterClaimsERS(_ serviceregistry.ServiceConfig, logger *logger.Logger) (any, serviceregistry.HandlerServer) { - return &ClaimsEntityResolutionService{logger: logger}, - func(ctx context.Context, mux *runtime.ServeMux, server any) error { - return entityresolution.RegisterEntityResolutionServiceHandlerServer(ctx, mux, server.(entityresolution.EntityResolutionServiceServer)) //nolint:forcetypeassert // allow type assert, following other services +func RegisterClaimsERS(_ serviceregistry.ServiceConfig, logger *logger.Logger) (ClaimsEntityResolutionService, serviceregistry.HandlerServer) { + claimsSVC := ClaimsEntityResolutionService{logger: logger} + return claimsSVC, + func(ctx context.Context, mux *runtime.ServeMux) error { + return entityresolution.RegisterEntityResolutionServiceHandlerServer(ctx, mux, claimsSVC) } } @@ -64,7 +65,7 @@ func EntityResolution(_ context.Context, var resolvedEntities []*entityresolution.EntityRepresentation for idx, ident := range payload { - var entityStruct = &structpb.Struct{} + entityStruct := &structpb.Struct{} switch ident.GetEntityType().(type) { case *authorization.Entity_Claims: claims := ident.GetClaims() diff --git a/service/entityresolution/entityresolution.go b/service/entityresolution/entityresolution.go index 9fd6c81e2..d66c2dc6b 100644 --- a/service/entityresolution/entityresolution.go +++ b/service/entityresolution/entityresolution.go @@ -12,25 +12,35 @@ type ERSConfig struct { Mode string `mapstructure:"mode" json:"mode"` } -const KeycloakMode = "keycloak" -const ClaimsMode = "claims" +const ( + KeycloakMode = "keycloak" + ClaimsMode = "claims" +) + +type EntityResolution struct { + entityresolution.EntityResolutionServiceServer +} -func NewRegistration() serviceregistry.Registration { - return serviceregistry.Registration{ - Namespace: "entityresolution", - ServiceDesc: &entityresolution.EntityResolutionService_ServiceDesc, - RegisterFunc: func(srp serviceregistry.RegistrationParams) (any, serviceregistry.HandlerServer) { - var inputConfig ERSConfig +func NewRegistration() *serviceregistry.Service[EntityResolution] { + return &serviceregistry.Service[EntityResolution]{ + ServiceOptions: serviceregistry.ServiceOptions[EntityResolution]{ + Namespace: "entityresolution", + ServiceDesc: &entityresolution.EntityResolutionService_ServiceDesc, + RegisterFunc: func(srp serviceregistry.RegistrationParams) (*EntityResolution, serviceregistry.HandlerServer) { + var inputConfig ERSConfig - if err := mapstructure.Decode(srp.Config, &inputConfig); err != nil { - panic(err) - } - if inputConfig.Mode == ClaimsMode { - return claims.RegisterClaimsERS(srp.Config, srp.Logger) - } + if err := mapstructure.Decode(srp.Config, &inputConfig); err != nil { + panic(err) + } + if inputConfig.Mode == ClaimsMode { + claimsSVC, claimsHandler := claims.RegisterClaimsERS(srp.Config, srp.Logger) + return &EntityResolution{EntityResolutionServiceServer: claimsSVC}, claimsHandler + } - // Default to keyclaok ERS - return keycloak.RegisterKeycloakERS(srp.Config, srp.Logger) + // Default to keycloak ERS + kcSVC, kcHandler := keycloak.RegisterKeycloakERS(srp.Config, srp.Logger) + return &EntityResolution{EntityResolutionServiceServer: kcSVC}, kcHandler + }, }, } } diff --git a/service/entityresolution/keycloak/keycloak_entity_resolution.go b/service/entityresolution/keycloak/keycloak_entity_resolution.go index a2a9b484c..16d3ba7e2 100644 --- a/service/entityresolution/keycloak/keycloak_entity_resolution.go +++ b/service/entityresolution/keycloak/keycloak_entity_resolution.go @@ -52,16 +52,16 @@ type KeycloakConfig struct { InferID InferredIdentityConfig `mapstructure:"inferid,omitempty" json:"inferid,omitempty"` } -func RegisterKeycloakERS(config serviceregistry.ServiceConfig, logger *logger.Logger) (any, serviceregistry.HandlerServer) { +func RegisterKeycloakERS(config serviceregistry.ServiceConfig, logger *logger.Logger) (*KeycloakEntityResolutionService, serviceregistry.HandlerServer) { var inputIdpConfig KeycloakConfig if err := mapstructure.Decode(config, &inputIdpConfig); err != nil { panic(err) } logger.Debug("entity_resolution configuration", "config", inputIdpConfig) - - return &KeycloakEntityResolutionService{idpConfig: inputIdpConfig, logger: logger}, - func(ctx context.Context, mux *runtime.ServeMux, server any) error { - return entityresolution.RegisterEntityResolutionServiceHandlerServer(ctx, mux, server.(entityresolution.EntityResolutionServiceServer)) //nolint:forcetypeassert // allow type assert, following other services + keycloakSVC := &KeycloakEntityResolutionService{idpConfig: inputIdpConfig, logger: logger} + return keycloakSVC, + func(ctx context.Context, mux *runtime.ServeMux) error { + return entityresolution.RegisterEntityResolutionServiceHandlerServer(ctx, mux, keycloakSVC) } } diff --git a/service/go.mod b/service/go.mod index 51d945bca..3056b9d5f 100644 --- a/service/go.mod +++ b/service/go.mod @@ -3,6 +3,7 @@ module github.com/opentdf/platform/service go 1.22 require ( + connectrpc.com/connect v1.17.0 github.com/Masterminds/squirrel v1.5.4 github.com/Nerzal/gocloak/v13 v13.9.0 github.com/bmatcuk/doublestar v1.3.4 diff --git a/service/go.sum b/service/go.sum index 98234e600..e64f2e4af 100644 --- a/service/go.sum +++ b/service/go.sum @@ -1,5 +1,7 @@ buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.34.1-20240508200655-46a4cf4ba109.1 h1:LEXWFH/xZ5oOWrC3oOtHbUyBdzRWMCPpAQmKC9v05mA= buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.34.1-20240508200655-46a4cf4ba109.1/go.mod h1:XF+P8+RmfdufmIYpGUC+6bF7S+IlmHDEnCrO3OXaUAQ= +connectrpc.com/connect v1.17.0 h1:W0ZqMhtVzn9Zhn2yATuUokDLO5N+gIuBWMOnsQrfmZk= +connectrpc.com/connect v1.17.0/go.mod h1:0292hj1rnx8oFrStN7cB4jjVBeqs+Yx5yDIC2prWDO8= dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk= dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24 h1:bvDV9vkmnHYOMsOr4WLk+Vo07yKIzd94sVoIqshQ4bU= diff --git a/service/health/health.go b/service/health/health.go index a8e53c4a1..053da5dff 100644 --- a/service/health/health.go +++ b/service/health/health.go @@ -12,29 +12,29 @@ import ( "google.golang.org/grpc/status" ) -var ( - serviceHealthChecks = make(map[string]func(context.Context) error) -) +var serviceHealthChecks = make(map[string]func(context.Context) error) type HealthService struct { //nolint:revive // HealthService is a valid name for this struct healthpb.UnimplementedHealthServer logger *logger.Logger } -func NewRegistration() serviceregistry.Registration { - return serviceregistry.Registration{ - Namespace: "health", - ServiceDesc: &healthpb.Health_ServiceDesc, - RegisterFunc: func(srp serviceregistry.RegistrationParams) (any, serviceregistry.HandlerServer) { - err := srp.WellKnownConfig("health", map[string]any{ - "endpoint": "/healthz", - }) - if err != nil { - srp.Logger.Error("failed to set well-known config", slog.String("error", err.Error())) - } - return &HealthService{logger: srp.Logger}, func(_ context.Context, _ *runtime.ServeMux, _ any) error { - return nil - } +func NewRegistration() *serviceregistry.Service[HealthService] { + return &serviceregistry.Service[HealthService]{ + ServiceOptions: serviceregistry.ServiceOptions[HealthService]{ + Namespace: "health", + ServiceDesc: &healthpb.Health_ServiceDesc, + RegisterFunc: func(srp serviceregistry.RegistrationParams) (*HealthService, serviceregistry.HandlerServer) { + err := srp.WellKnownConfig("health", map[string]any{ + "endpoint": "/healthz", + }) + if err != nil { + srp.Logger.Error("failed to set well-known config", slog.String("error", err.Error())) + } + return &HealthService{logger: srp.Logger}, func(_ context.Context, _ *runtime.ServeMux) error { + return nil + } + }, }, } } diff --git a/service/kas/kas.go b/service/kas/kas.go index 8cb3c1496..280a068ec 100644 --- a/service/kas/kas.go +++ b/service/kas/kas.go @@ -15,77 +15,75 @@ import ( "github.com/opentdf/platform/service/pkg/serviceregistry" ) -func NewRegistration() serviceregistry.Registration { - return serviceregistry.Registration{ - Namespace: "kas", - ServiceDesc: &kaspb.AccessService_ServiceDesc, - RegisterFunc: func(srp serviceregistry.RegistrationParams) (any, serviceregistry.HandlerServer) { - // FIXME msg="mismatched key access url" keyAccessURL=http://localhost:9000 kasURL=https://:9000 - hostWithPort := srp.OTDF.HTTPServer.Addr - if strings.HasPrefix(hostWithPort, ":") { - hostWithPort = "localhost" + hostWithPort - } - kasURLString := "http://" + hostWithPort - kasURI, err := url.Parse(kasURLString) - if err != nil { - panic(fmt.Errorf("invalid kas address [%s] %w", kasURLString, err)) - } +func NewRegistration() *serviceregistry.Service[access.Provider] { + return &serviceregistry.Service[access.Provider]{ + ServiceOptions: serviceregistry.ServiceOptions[access.Provider]{ + Namespace: "kas", + ServiceDesc: &kaspb.AccessService_ServiceDesc, + RegisterFunc: func(srp serviceregistry.RegistrationParams) (*access.Provider, serviceregistry.HandlerServer) { + // FIXME msg="mismatched key access url" keyAccessURL=http://localhost:9000 kasURL=https://:9000 + hostWithPort := srp.OTDF.HTTPServer.Addr + if strings.HasPrefix(hostWithPort, ":") { + hostWithPort = "localhost" + hostWithPort + } + kasURLString := "http://" + hostWithPort + kasURI, err := url.Parse(kasURLString) + if err != nil { + panic(fmt.Errorf("invalid kas address [%s] %w", kasURLString, err)) + } - var kasCfg access.KASConfig - if err := mapstructure.Decode(srp.Config, &kasCfg); err != nil { - panic(fmt.Errorf("invalid kas cfg [%v] %w", srp.Config, err)) - } + var kasCfg access.KASConfig + if err := mapstructure.Decode(srp.Config, &kasCfg); err != nil { + panic(fmt.Errorf("invalid kas cfg [%v] %w", srp.Config, err)) + } - switch { - case kasCfg.ECCertID != "" && len(kasCfg.Keyring) > 0: - panic("invalid kas cfg: please specify keyring or eccertid, not both") - case len(kasCfg.Keyring) == 0: - deprecatedOrDefault := func(kid, alg string) { - if kid == "" { - kid = srp.OTDF.CryptoProvider.FindKID(alg) + switch { + case kasCfg.ECCertID != "" && len(kasCfg.Keyring) > 0: + panic("invalid kas cfg: please specify keyring or eccertid, not both") + case len(kasCfg.Keyring) == 0: + deprecatedOrDefault := func(kid, alg string) { + if kid == "" { + kid = srp.OTDF.CryptoProvider.FindKID(alg) + } + if kid == "" { + srp.Logger.Warn("no known key for alg", "algorithm", alg) + return + } + kasCfg.Keyring = append(kasCfg.Keyring, access.CurrentKeyFor{ + Algorithm: alg, + KID: kid, + }) + kasCfg.Keyring = append(kasCfg.Keyring, access.CurrentKeyFor{ + Algorithm: alg, + KID: kid, + Legacy: true, + }) } - if kid == "" { - srp.Logger.Warn("no known key for alg", "algorithm", alg) - return - } - kasCfg.Keyring = append(kasCfg.Keyring, access.CurrentKeyFor{ - Algorithm: alg, - KID: kid, - }) - kasCfg.Keyring = append(kasCfg.Keyring, access.CurrentKeyFor{ - Algorithm: alg, - KID: kid, - Legacy: true, - }) + deprecatedOrDefault(kasCfg.ECCertID, security.AlgorithmECP256R1) + deprecatedOrDefault(kasCfg.RSACertID, security.AlgorithmRSA2048) + default: + kasCfg.Keyring = append(kasCfg.Keyring, inferLegacyKeys(kasCfg.Keyring)...) } - deprecatedOrDefault(kasCfg.ECCertID, security.AlgorithmECP256R1) - deprecatedOrDefault(kasCfg.RSACertID, security.AlgorithmRSA2048) - default: - kasCfg.Keyring = append(kasCfg.Keyring, inferLegacyKeys(kasCfg.Keyring)...) - } - p := access.Provider{ - URI: *kasURI, - AttributeSvc: nil, - CryptoProvider: srp.OTDF.CryptoProvider, - SDK: srp.SDK, - Logger: srp.Logger, - KASConfig: kasCfg, - } + p := access.Provider{ + URI: *kasURI, + AttributeSvc: nil, + CryptoProvider: srp.OTDF.CryptoProvider, + SDK: srp.SDK, + Logger: srp.Logger, + KASConfig: kasCfg, + } - srp.Logger.Debug("kas config", "config", kasCfg) + srp.Logger.Debug("kas config", "config", kasCfg) - if err := srp.RegisterReadinessCheck("kas", p.IsReady); err != nil { - srp.Logger.Error("failed to register kas readiness check", slog.String("error", err.Error())) - } + if err := srp.RegisterReadinessCheck("kas", p.IsReady); err != nil { + srp.Logger.Error("failed to register kas readiness check", slog.String("error", err.Error())) + } - return &p, func(ctx context.Context, mux *runtime.ServeMux, server any) error { - kas, ok := server.(*access.Provider) - if !ok { - panic("invalid kas server object") + return &p, func(ctx context.Context, mux *runtime.ServeMux) error { + return kaspb.RegisterAccessServiceHandlerServer(ctx, mux, &p) } - return kaspb.RegisterAccessServiceHandlerServer(ctx, mux, kas) - } + }, }, } } diff --git a/service/pkg/server/options.go b/service/pkg/server/options.go index 75eed19ee..6a44e268f 100644 --- a/service/pkg/server/options.go +++ b/service/pkg/server/options.go @@ -12,8 +12,8 @@ type StartConfig struct { WaitForShutdownSignal bool PublicRoutes []string authzDefaultPolicyExtension [][]string - extraCoreServices []serviceregistry.Registration - extraServices []serviceregistry.Registration + extraCoreServices []serviceregistry.IService + extraServices []serviceregistry.IService } // Deprecated: Use WithConfigKey @@ -73,7 +73,7 @@ func WithAuthZDefaultPolicyExtension(policies [][]string) StartOptions { // WithCoreServices option adds additional core services to the platform // It takes a variadic parameter of type serviceregistry.Registration, which represents the core services to be added. -func WithCoreServices(services ...serviceregistry.Registration) StartOptions { +func WithCoreServices(services ...serviceregistry.IService) StartOptions { return func(c StartConfig) StartConfig { c.extraCoreServices = append(c.extraCoreServices, services...) return c @@ -83,7 +83,7 @@ func WithCoreServices(services ...serviceregistry.Registration) StartOptions { // WithServices option adds additional services to the platform. // This will set the mode for these services to the namespace name. // It takes a variadic parameter of type serviceregistry.Registration, which represents the services to be added. -func WithServices(services ...serviceregistry.Registration) StartOptions { +func WithServices(services ...serviceregistry.IService) StartOptions { return func(c StartConfig) StartConfig { c.extraServices = append(c.extraServices, services...) return c diff --git a/service/pkg/server/services.go b/service/pkg/server/services.go index 6941609a8..14d4e8f60 100644 --- a/service/pkg/server/services.go +++ b/service/pkg/server/services.go @@ -39,7 +39,7 @@ const ( // registerEssentialServices registers the essential services to the given service registry. // It takes a serviceregistry.Registry as input and returns an error if registration fails. func registerEssentialServices(reg serviceregistry.Registry) error { - essentialServices := []serviceregistry.Registration{ + essentialServices := []serviceregistry.IService{ health.NewRegistration(), } // Register the essential services @@ -55,7 +55,7 @@ func registerEssentialServices(reg serviceregistry.Registry) error { // It returns the list of registered services and any error encountered during registration. func registerCoreServices(reg serviceregistry.Registry, mode []string) ([]string, error) { var ( - services []serviceregistry.Registration + services []serviceregistry.IService registeredServices []string ) @@ -63,7 +63,7 @@ func registerCoreServices(reg serviceregistry.Registry, mode []string) ([]string switch m { case "all": registeredServices = append(registeredServices, []string{servicePolicy, serviceAuthorization, serviceKAS, serviceWellKnown, serviceEntityResolution}...) - services = append(services, []serviceregistry.Registration{ + services = append(services, []serviceregistry.IService{ authorization.NewRegistration(), kas.NewRegistration(), wellknown.NewRegistration(), @@ -72,7 +72,7 @@ func registerCoreServices(reg serviceregistry.Registry, mode []string) ([]string services = append(services, policy.NewRegistrations()...) case "core": registeredServices = append(registeredServices, []string{servicePolicy, serviceAuthorization, serviceWellKnown}...) - services = append(services, []serviceregistry.Registration{ + services = append(services, []serviceregistry.IService{ entityresolution.NewRegistration(), authorization.NewRegistration(), wellknown.NewRegistration(), @@ -142,17 +142,17 @@ func startServices(ctx context.Context, cfg config.Config, otdf *server.OpenTDFS for _, svc := range namespace.Services { // Get new db client if it is required and not already created - if svc.DB.Required && svcDBClient == nil { + if svc.IsDBRequired() && svcDBClient == nil { logger.Debug("creating database client", slog.String("namespace", ns)) var err error - svcDBClient, err = newServiceDBClient(ctx, cfg.Logger, cfg.DB, ns, svc.DB.Migrations) + svcDBClient, err = newServiceDBClient(ctx, cfg.Logger, cfg.DB, ns, svc.DBMigrations()) if err != nil { return err } } err = svc.Start(ctx, serviceregistry.RegistrationParams{ - Config: cfg.Services[svc.Namespace], + Config: cfg.Services[svc.GetNamespace()], Logger: svcLogger, DBClient: svcDBClient, SDK: client, @@ -174,7 +174,7 @@ func startServices(ctx context.Context, cfg config.Config, otdf *server.OpenTDFS } // Register the service with the gRPC gateway - if err := svc.RegisterHTTPServer(ctx, otdf.Mux); err != nil { //nolint:staticcheck // This is deprecated for internal tracking + if err := svc.RegisterHTTPServer(ctx, otdf.Mux); err != nil { logger.Error("failed to register service to grpc gateway", slog.String("namespace", ns), slog.String("error", err.Error())) return err } @@ -182,7 +182,7 @@ func startServices(ctx context.Context, cfg config.Config, otdf *server.OpenTDFS logger.Info( "service running", slog.String("namespace", ns), - slog.String("service", svc.ServiceDesc.ServiceName), + slog.String("service", svc.GetServiceDesc().ServiceName), slog.Group("database", slog.Any("required", svcDBClient != nil), slog.Any("migrationStatus", determineStatusOfMigration(svcDBClient)), diff --git a/service/pkg/server/services_test.go b/service/pkg/server/services_test.go index ee50ab6fc..1f9037f40 100644 --- a/service/pkg/server/services_test.go +++ b/service/pkg/server/services_test.go @@ -22,17 +22,17 @@ type mockTestServiceOptions struct { serviceName string serviceHandlerType any serviceObject any - serviceHandler func(ctx context.Context, mux *runtime.ServeMux, server any) error + serviceHandler func(ctx context.Context, mux *runtime.ServeMux) error dbRegister serviceregistry.DBRegister } -func mockTestServiceRegistry(opts mockTestServiceOptions) (serviceregistry.Registration, *spyTestService) { +func mockTestServiceRegistry(opts mockTestServiceOptions) (serviceregistry.IService, *spyTestService) { spy := &spyTestService{} mockTestServiceDefaults := mockTestServiceOptions{ namespace: "test", serviceName: "TestService", serviceHandlerType: (*interface{})(nil), - serviceHandler: func(_ context.Context, _ *runtime.ServeMux, _ any) error { + serviceHandler: func(_ context.Context, _ *runtime.ServeMux) error { return nil }, } @@ -52,21 +52,28 @@ func mockTestServiceRegistry(opts mockTestServiceOptions) (serviceregistry.Regis serviceHandler = opts.serviceHandler } - return serviceregistry.Registration{ - Namespace: namespace, - ServiceDesc: &grpc.ServiceDesc{ - ServiceName: serviceName, - HandlerType: serviceHandlerType, + return &serviceregistry.Service[TestService]{ + ServiceOptions: serviceregistry.ServiceOptions[TestService]{ + Namespace: namespace, + ServiceDesc: &grpc.ServiceDesc{ + ServiceName: serviceName, + HandlerType: serviceHandlerType, + }, + RegisterFunc: func(srp serviceregistry.RegistrationParams) (*TestService, serviceregistry.HandlerServer) { + var ts *TestService + var ok bool + if ts, ok = opts.serviceObject.(*TestService); !ok { + panic("serviceObject is not a TestService") + } + return ts, func(ctx context.Context, mux *runtime.ServeMux) error { + spy.wasCalled = true + spy.callParams = append(spy.callParams, srp, ctx, mux, ts) + return serviceHandler(ctx, mux) + } + }, + + DB: opts.dbRegister, }, - RegisterFunc: func(srp serviceregistry.RegistrationParams) (any, serviceregistry.HandlerServer) { - return opts.serviceObject, func(ctx context.Context, mux *runtime.ServeMux, server any) error { - spy.wasCalled = true - spy.callParams = append(spy.callParams, srp, ctx, mux, server) - return serviceHandler(ctx, mux, server) - } - }, - - DB: opts.dbRegister, }, spy } diff --git a/service/pkg/server/start.go b/service/pkg/server/start.go index 1b3af7c8b..8a82dec64 100644 --- a/service/pkg/server/start.go +++ b/service/pkg/server/start.go @@ -126,9 +126,9 @@ func Start(f ...StartOptions) error { if len(startConfig.extraServices) > 0 { logger.Debug("registering extra services") for _, service := range startConfig.extraServices { - err := svcRegistry.RegisterService(service, service.Namespace) + err := svcRegistry.RegisterService(service, service.GetNamespace()) if err != nil { - logger.Error("could not register extra service", slog.String("namespace", service.Namespace), slog.String("error", err.Error())) + logger.Error("could not register extra service", slog.String("namespace", service.GetNamespace()), slog.String("error", err.Error())) return fmt.Errorf("could not register extra service: %w", err) } } diff --git a/service/pkg/server/start_test.go b/service/pkg/server/start_test.go index 0d6a6c51c..5ab8ed314 100644 --- a/service/pkg/server/start_test.go +++ b/service/pkg/server/start_test.go @@ -2,15 +2,13 @@ package server import ( "context" - "fmt" "io" + "log/slog" "net/http" "net/http/httptest" "testing" "time" - "log/slog" - "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/opentdf/platform/service/internal/auth" "github.com/opentdf/platform/service/internal/config" @@ -22,8 +20,10 @@ import ( "github.com/stretchr/testify/suite" ) -type TestServiceService interface{} -type TestService struct{} +type ( + TestServiceService interface{} + TestService struct{} +) func (t TestService) TestHandler(w http.ResponseWriter, _ *http.Request, _ map[string]string) { _, err := w.Write([]byte("hello from test service!")) @@ -99,14 +99,11 @@ func (suite *StartTestSuite) Test_Start_When_Extra_Service_Registered_Expect_Res require.NoError(t, err) // Register Test Service + ts := &TestService{} registerTestService, _ := mockTestServiceRegistry(mockTestServiceOptions{ - serviceObject: &TestService{}, - serviceHandler: func(_ context.Context, mux *runtime.ServeMux, server any) error { - t, ok := server.(*TestService) - if !ok { - return fmt.Errorf("Surprise! Not a TestService") - } - return mux.HandlePath(http.MethodGet, "/healthz", t.TestHandler) + serviceObject: ts, + serviceHandler: func(_ context.Context, mux *runtime.ServeMux) error { + return mux.HandlePath(http.MethodGet, "/healthz", ts.TestHandler) }, }) diff --git a/service/pkg/serviceregistry/serviceregistry.go b/service/pkg/serviceregistry/serviceregistry.go index a5c1f7157..2d9ca082e 100644 --- a/service/pkg/serviceregistry/serviceregistry.go +++ b/service/pkg/serviceregistry/serviceregistry.go @@ -5,8 +5,10 @@ import ( "embed" "fmt" "log/slog" + "net/http" "slices" + "connectrpc.com/connect" "github.com/opentdf/platform/sdk" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" @@ -47,23 +49,10 @@ type RegistrationParams struct { // ready to serve requests. This function should be called in the RegisterFunc function. RegisterReadinessCheck func(namespace string, check func(context.Context) error) error } -type HandlerServer func(ctx context.Context, mux *runtime.ServeMux, server any) error -type RegisterFunc func(RegistrationParams) (Impl any, HandlerServer HandlerServer) - -// Registration is a struct that holds the information needed to register a service -type Registration struct { - // Namespace is the namespace of the service. One or more gRPC services can be registered under - // the same namespace. - Namespace string - // ServiceDesc is the gRPC service descriptor. For non-gRPC services, this can be mocked out, - // but at minimum, the ServiceName field must be set - ServiceDesc *grpc.ServiceDesc - // RegisterFunc is the function that will be called to register the service - RegisterFunc RegisterFunc - - // DB is optional and used to register the service with a database - DB DBRegister -} +type ( + HandlerServer func(ctx context.Context, mux *runtime.ServeMux) error + RegisterFunc[S any] func(RegistrationParams) (Impl *S, HandlerServer HandlerServer) +) // DBRegister is a struct that holds the information needed to register a service with a database type DBRegister struct { @@ -75,21 +64,77 @@ type DBRegister struct { Migrations *embed.FS } +type IService interface { + IsDBRequired() bool + DBMigrations() *embed.FS + GetNamespace() string + GetServiceDesc() *grpc.ServiceDesc + Start(ctx context.Context, params RegistrationParams) error + IsStarted() bool + Shutdown() error + RegisterGRPCServer(server *grpc.Server) error + RegisterHTTPServer(ctx context.Context, mux *runtime.ServeMux) error +} + // Service is a struct that holds the registration information for a service as well as the state // of the service within the instance of the platform. -type Service struct { - Registration - impl any - handleFunc HandlerServer +type Service[S any] struct { + // Registration + impl *S // Started is a flag that indicates whether the service has been started Started bool // Close is a function that can be called to close the service Close func() + // Service Options + ServiceOptions[S] +} + +type ServiceOptions[S any] struct { + // Namespace is the namespace of the service. One or more gRPC services can be registered under + // the same namespace. + Namespace string + // ServiceDesc is the gRPC service descriptor. For non-gRPC services, this can be mocked out, + // but at minimum, the ServiceName field must be set + ServiceDesc *grpc.ServiceDesc + // RegisterFunc is the function that will be called to register the service + RegisterFunc RegisterFunc[S] + httpHandlerFunc HandlerServer + // ConnectRPCServiceHandler is the function that will be called to register the service with the + ConnectRPCFunc func(S, ...connect.HandlerOption) (string, http.Handler) + // DB is optional and used to register the service with a database + DB DBRegister +} + +func (s Service[S]) GetNamespace() string { + return s.Namespace +} + +func (s Service[S]) GetServiceDesc() *grpc.ServiceDesc { + return s.ServiceDesc +} + +func (s Service[S]) IsStarted() bool { + return s.Started +} + +func (s Service[S]) Shutdown() error { + if s.Close != nil { + s.Close() + } + return nil +} + +func (s Service[S]) IsDBRequired() bool { + return s.DB.Required +} + +func (s Service[S]) DBMigrations() *embed.FS { + return s.DB.Migrations } // Start starts the service and performs necessary initialization steps. // It returns an error if the service is already started or if there is an issue running database migrations. -func (s *Service) Start(ctx context.Context, params RegistrationParams) error { +func (s *Service[S]) Start(ctx context.Context, params RegistrationParams) error { if s.Started { return fmt.Errorf("service already started") } @@ -104,7 +149,7 @@ func (s *Service) Start(ctx context.Context, params RegistrationParams) error { ) } - s.impl, s.handleFunc = s.RegisterFunc(params) + s.impl, s.httpHandlerFunc = s.RegisterFunc(params) s.Started = true return nil @@ -113,7 +158,7 @@ func (s *Service) Start(ctx context.Context, params RegistrationParams) error { // RegisterGRPCServer registers the gRPC server with the service implementation. // It checks if the service implementation is registered and then registers the service with the server. // It returns an error if the service implementation is not registered. -func (s *Service) RegisterGRPCServer(server *grpc.Server) error { +func (s *Service[S]) RegisterGRPCServer(server *grpc.Server) error { if s.impl == nil { return fmt.Errorf("service did not register an implementation") } @@ -126,17 +171,17 @@ func (s *Service) RegisterGRPCServer(server *grpc.Server) error { // RegisterHTTPServer registers an HTTP server with the service. // It takes a context, a ServeMux, and an implementation function as parameters. // If the service did not register a handler, it returns an error. -func (s *Service) RegisterHTTPServer(ctx context.Context, mux *runtime.ServeMux) error { - if s.handleFunc == nil { +func (s *Service[S]) RegisterHTTPServer(ctx context.Context, mux *runtime.ServeMux) error { + if s.httpHandlerFunc == nil { return fmt.Errorf("service did not register a handler") } - return s.handleFunc(ctx, mux, s.impl) + return s.httpHandlerFunc(ctx, mux) } // namespace represents a namespace in the service registry. type Namespace struct { Mode string - Services []Service + Services []IService } // Registry represents a map of service namespaces. @@ -150,8 +195,8 @@ func NewServiceRegistry() Registry { // RegisterCoreService registers a core service with the given registration information. // It calls the RegisterService method of the Registry instance with the provided registration and service type "core". // Returns an error if the registration fails. -func (reg Registry) RegisterCoreService(r Registration) error { - return reg.RegisterService(r, "core") +func (reg Registry) RegisterCoreService(svc IService) error { + return reg.RegisterService(svc, "core") } // RegisterService registers a service in the service registry. @@ -160,27 +205,25 @@ func (reg Registry) RegisterCoreService(r Registration) error { // such as the namespace and service description. // The mode string specifies the mode in which the service should be registered. // It returns an error if the service is already registered in the specified namespace. -func (reg Registry) RegisterService(r Registration, mode string) error { +func (reg Registry) RegisterService(svc IService, mode string) error { // Can't directly modify structs within a map, so we need to copy the namespace - copyNamespace := reg[r.Namespace] + copyNamespace := reg[svc.GetNamespace()] copyNamespace.Mode = mode if copyNamespace.Services == nil { - copyNamespace.Services = make([]Service, 0) + copyNamespace.Services = make([]IService, 0) } - found := slices.ContainsFunc(reg[r.Namespace].Services, func(s Service) bool { - return s.ServiceDesc.ServiceName == r.ServiceDesc.ServiceName + found := slices.ContainsFunc(reg[svc.GetNamespace()].Services, func(s IService) bool { + return s.GetServiceDesc().ServiceName == svc.GetServiceDesc().ServiceName }) if found { - return fmt.Errorf("service already registered namespace:%s service:%s", r.Namespace, r.ServiceDesc.ServiceName) + return fmt.Errorf("service already registered namespace:%s service:%s", svc.GetNamespace(), svc.GetServiceDesc().ServiceName) } - slog.Info("registered service", slog.String("namespace", r.Namespace), slog.String("service", r.ServiceDesc.ServiceName)) - copyNamespace.Services = append(copyNamespace.Services, Service{ - Registration: r, - }) + slog.Info("registered service", slog.String("namespace", svc.GetNamespace()), slog.String("service", svc.GetServiceDesc().ServiceName)) + copyNamespace.Services = append(copyNamespace.Services, svc) - reg[r.Namespace] = copyNamespace + reg[svc.GetNamespace()] = copyNamespace return nil } @@ -191,9 +234,11 @@ func (reg Registry) RegisterService(r Registration, mode string) error { func (reg Registry) Shutdown() { for name, ns := range reg { for _, svc := range ns.Services { - if svc.Close != nil && svc.Started { - slog.Info("stopping service", slog.String("namespace", name), slog.String("service", svc.ServiceDesc.ServiceName)) - svc.Close() + if svc.IsStarted() { + slog.Info("stopping service", slog.String("namespace", name), slog.String("service", svc.GetServiceDesc().ServiceName)) + if err := svc.Shutdown(); err != nil { + slog.Error("error stopping service", slog.String("namespace", name), slog.String("service", svc.GetServiceDesc().ServiceName), slog.String("error", err.Error())) + } } } } diff --git a/service/policy/attributes/attributes.go b/service/policy/attributes/attributes.go index 2bae2efe2..c3d31aa83 100644 --- a/service/policy/attributes/attributes.go +++ b/service/policy/attributes/attributes.go @@ -21,16 +21,18 @@ type AttributesService struct { //nolint:revive // AttributesService is a valid logger *logger.Logger } -func NewRegistration() serviceregistry.Registration { - return serviceregistry.Registration{ - ServiceDesc: &attributes.AttributesService_ServiceDesc, - RegisterFunc: func(srp serviceregistry.RegistrationParams) (any, serviceregistry.HandlerServer) { - return &AttributesService{dbClient: policydb.NewClient(srp.DBClient, srp.Logger), logger: srp.Logger}, func(ctx context.Context, mux *runtime.ServeMux, server any) error { - if srv, ok := server.(attributes.AttributesServiceServer); ok { - return attributes.RegisterAttributesServiceHandlerServer(ctx, mux, srv) +func NewRegistration(ns string, dbRegister serviceregistry.DBRegister) *serviceregistry.Service[AttributesService] { + return &serviceregistry.Service[AttributesService]{ + ServiceOptions: serviceregistry.ServiceOptions[AttributesService]{ + Namespace: ns, + DB: dbRegister, + ServiceDesc: &attributes.AttributesService_ServiceDesc, + RegisterFunc: func(srp serviceregistry.RegistrationParams) (*AttributesService, serviceregistry.HandlerServer) { + as := &AttributesService{dbClient: policydb.NewClient(srp.DBClient, srp.Logger), logger: srp.Logger} + return as, func(ctx context.Context, mux *runtime.ServeMux) error { + return attributes.RegisterAttributesServiceHandlerServer(ctx, mux, as) } - return fmt.Errorf("failed to assert server as attributes.AttributesServiceServer") - } + }, }, } } diff --git a/service/policy/kasregistry/key_access_server_registry.go b/service/policy/kasregistry/key_access_server_registry.go index 3cb4bc2e0..db821d1ac 100644 --- a/service/policy/kasregistry/key_access_server_registry.go +++ b/service/policy/kasregistry/key_access_server_registry.go @@ -2,7 +2,6 @@ package kasregistry import ( "context" - "fmt" "log/slog" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" @@ -21,17 +20,18 @@ type KeyAccessServerRegistry struct { logger *logger.Logger } -func NewRegistration() serviceregistry.Registration { - return serviceregistry.Registration{ - ServiceDesc: &kasr.KeyAccessServerRegistryService_ServiceDesc, - RegisterFunc: func(srp serviceregistry.RegistrationParams) (any, serviceregistry.HandlerServer) { - return &KeyAccessServerRegistry{dbClient: policydb.NewClient(srp.DBClient, srp.Logger), logger: srp.Logger}, func(ctx context.Context, mux *runtime.ServeMux, s any) error { - srv, ok := s.(kasr.KeyAccessServerRegistryServiceServer) - if !ok { - return fmt.Errorf("argument is not of type kasr.KeyAccessServerRegistryServiceServer") +func NewRegistration(ns string, dbRegister serviceregistry.DBRegister) *serviceregistry.Service[KeyAccessServerRegistry] { + return &serviceregistry.Service[KeyAccessServerRegistry]{ + ServiceOptions: serviceregistry.ServiceOptions[KeyAccessServerRegistry]{ + Namespace: ns, + DB: dbRegister, + ServiceDesc: &kasr.KeyAccessServerRegistryService_ServiceDesc, + RegisterFunc: func(srp serviceregistry.RegistrationParams) (*KeyAccessServerRegistry, serviceregistry.HandlerServer) { + ksr := &KeyAccessServerRegistry{dbClient: policydb.NewClient(srp.DBClient, srp.Logger), logger: srp.Logger} + return ksr, func(ctx context.Context, mux *runtime.ServeMux) error { + return kasr.RegisterKeyAccessServerRegistryServiceHandlerServer(ctx, mux, ksr) } - return kasr.RegisterKeyAccessServerRegistryServiceHandlerServer(ctx, mux, srv) - } + }, }, } } diff --git a/service/policy/namespaces/namespaces.go b/service/policy/namespaces/namespaces.go index 5598f5880..f9ca32f53 100644 --- a/service/policy/namespaces/namespaces.go +++ b/service/policy/namespaces/namespaces.go @@ -21,23 +21,23 @@ type NamespacesService struct { //nolint:revive // NamespacesService is a valid logger *logger.Logger } -func NewRegistration() serviceregistry.Registration { - return serviceregistry.Registration{ - ServiceDesc: &namespaces.NamespaceService_ServiceDesc, - RegisterFunc: func(srp serviceregistry.RegistrationParams) (any, serviceregistry.HandlerServer) { - ns := &NamespacesService{dbClient: policydb.NewClient(srp.DBClient, srp.Logger), logger: srp.Logger} - - if err := srp.RegisterReadinessCheck("policy", ns.IsReady); err != nil { - srp.Logger.Error("failed to register policy readiness check", slog.String("error", err.Error())) - } - - return ns, func(ctx context.Context, mux *runtime.ServeMux, server any) error { - nsServer, ok := server.(namespaces.NamespaceServiceServer) - if !ok { - return fmt.Errorf("failed to assert server as namespaces.NamespaceServiceServer") +func NewRegistration(ns string, dbRegister serviceregistry.DBRegister) *serviceregistry.Service[NamespacesService] { + return &serviceregistry.Service[NamespacesService]{ + ServiceOptions: serviceregistry.ServiceOptions[NamespacesService]{ + Namespace: ns, + DB: dbRegister, + ServiceDesc: &namespaces.NamespaceService_ServiceDesc, + RegisterFunc: func(srp serviceregistry.RegistrationParams) (*NamespacesService, serviceregistry.HandlerServer) { + ns := &NamespacesService{dbClient: policydb.NewClient(srp.DBClient, srp.Logger), logger: srp.Logger} + + if err := srp.RegisterReadinessCheck("policy", ns.IsReady); err != nil { + srp.Logger.Error("failed to register policy readiness check", slog.String("error", err.Error())) } - return namespaces.RegisterNamespaceServiceHandlerServer(ctx, mux, nsServer) - } + + return ns, func(ctx context.Context, mux *runtime.ServeMux) error { + return namespaces.RegisterNamespaceServiceHandlerServer(ctx, mux, ns) + } + }, }, } } diff --git a/service/policy/policy.go b/service/policy/policy.go index 3096f31d2..69db81e86 100644 --- a/service/policy/policy.go +++ b/service/policy/policy.go @@ -19,25 +19,21 @@ func init() { Migrations = &migrations.FS } -func NewRegistrations() []serviceregistry.Registration { - registrations := []serviceregistry.Registration{} +func NewRegistrations() []serviceregistry.IService { + registrations := []serviceregistry.IService{} namespace := "policy" dbRegister := serviceregistry.DBRegister{ Required: true, Migrations: Migrations, } - for _, r := range []serviceregistry.Registration{ - attributes.NewRegistration(), - namespaces.NewRegistration(), - resourcemapping.NewRegistration(), - subjectmapping.NewRegistration(), - kasregistry.NewRegistration(), - unsafe.NewRegistration(), - } { - r.Namespace = namespace - r.DB = dbRegister - registrations = append(registrations, r) - } + registrations = append(registrations, []serviceregistry.IService{ + attributes.NewRegistration(namespace, dbRegister), + namespaces.NewRegistration(namespace, dbRegister), + resourcemapping.NewRegistration(namespace, dbRegister), + subjectmapping.NewRegistration(namespace, dbRegister), + kasregistry.NewRegistration(namespace, dbRegister), + unsafe.NewRegistration(namespace, dbRegister), + }...) return registrations } diff --git a/service/policy/resourcemapping/resource_mapping.go b/service/policy/resourcemapping/resource_mapping.go index 58d4f273c..54e732172 100644 --- a/service/policy/resourcemapping/resource_mapping.go +++ b/service/policy/resourcemapping/resource_mapping.go @@ -2,7 +2,6 @@ package resourcemapping import ( "context" - "fmt" "log/slog" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" @@ -21,17 +20,18 @@ type ResourceMappingService struct { //nolint:revive // ResourceMappingService i logger *logger.Logger } -func NewRegistration() serviceregistry.Registration { - return serviceregistry.Registration{ - ServiceDesc: &resourcemapping.ResourceMappingService_ServiceDesc, - RegisterFunc: func(srp serviceregistry.RegistrationParams) (any, serviceregistry.HandlerServer) { - return &ResourceMappingService{dbClient: policydb.NewClient(srp.DBClient, srp.Logger), logger: srp.Logger}, func(ctx context.Context, mux *runtime.ServeMux, s any) error { - server, ok := s.(resourcemapping.ResourceMappingServiceServer) - if !ok { - return fmt.Errorf("failed to assert server as resourcemapping.ResourceMappingServiceServer") +func NewRegistration(ns string, dbRegister serviceregistry.DBRegister) *serviceregistry.Service[ResourceMappingService] { + return &serviceregistry.Service[ResourceMappingService]{ + ServiceOptions: serviceregistry.ServiceOptions[ResourceMappingService]{ + Namespace: ns, + DB: dbRegister, + ServiceDesc: &resourcemapping.ResourceMappingService_ServiceDesc, + RegisterFunc: func(srp serviceregistry.RegistrationParams) (*ResourceMappingService, serviceregistry.HandlerServer) { + rm := &ResourceMappingService{dbClient: policydb.NewClient(srp.DBClient, srp.Logger), logger: srp.Logger} + return rm, func(ctx context.Context, mux *runtime.ServeMux) error { + return resourcemapping.RegisterResourceMappingServiceHandlerServer(ctx, mux, rm) } - return resourcemapping.RegisterResourceMappingServiceHandlerServer(ctx, mux, server) - } + }, }, } } diff --git a/service/policy/subjectmapping/subject_mapping.go b/service/policy/subjectmapping/subject_mapping.go index 3413ae4e4..a44842a18 100644 --- a/service/policy/subjectmapping/subject_mapping.go +++ b/service/policy/subjectmapping/subject_mapping.go @@ -2,7 +2,6 @@ package subjectmapping import ( "context" - "fmt" "log/slog" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" @@ -21,17 +20,18 @@ type SubjectMappingService struct { //nolint:revive // SubjectMappingService is logger *logger.Logger } -func NewRegistration() serviceregistry.Registration { - return serviceregistry.Registration{ - ServiceDesc: &sm.SubjectMappingService_ServiceDesc, - RegisterFunc: func(srp serviceregistry.RegistrationParams) (any, serviceregistry.HandlerServer) { - return &SubjectMappingService{dbClient: policydb.NewClient(srp.DBClient, srp.Logger), logger: srp.Logger}, func(ctx context.Context, mux *runtime.ServeMux, s any) error { - server, ok := s.(sm.SubjectMappingServiceServer) - if !ok { - return fmt.Errorf("failed to assert server as sm.SubjectMappingServiceServer") +func NewRegistration(ns string, dbRegister serviceregistry.DBRegister) *serviceregistry.Service[SubjectMappingService] { + return &serviceregistry.Service[SubjectMappingService]{ + ServiceOptions: serviceregistry.ServiceOptions[SubjectMappingService]{ + Namespace: ns, + DB: dbRegister, + ServiceDesc: &sm.SubjectMappingService_ServiceDesc, + RegisterFunc: func(srp serviceregistry.RegistrationParams) (*SubjectMappingService, serviceregistry.HandlerServer) { + smSvc := &SubjectMappingService{dbClient: policydb.NewClient(srp.DBClient, srp.Logger), logger: srp.Logger} + return smSvc, func(ctx context.Context, mux *runtime.ServeMux) error { + return sm.RegisterSubjectMappingServiceHandlerServer(ctx, mux, smSvc) } - return sm.RegisterSubjectMappingServiceHandlerServer(ctx, mux, server) - } + }, }, } } diff --git a/service/policy/unsafe/unsafe.go b/service/policy/unsafe/unsafe.go index bfade319f..ed1f2700e 100644 --- a/service/policy/unsafe/unsafe.go +++ b/service/policy/unsafe/unsafe.go @@ -2,7 +2,6 @@ package unsafe import ( "context" - "fmt" "log/slog" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" @@ -21,16 +20,18 @@ type UnsafeService struct { //nolint:revive // UnsafeService is a valid name for logger *logger.Logger } -func NewRegistration() serviceregistry.Registration { - return serviceregistry.Registration{ - ServiceDesc: &unsafe.UnsafeService_ServiceDesc, - RegisterFunc: func(srp serviceregistry.RegistrationParams) (any, serviceregistry.HandlerServer) { - return &UnsafeService{dbClient: policydb.NewClient(srp.DBClient, srp.Logger), logger: srp.Logger}, func(ctx context.Context, mux *runtime.ServeMux, server any) error { - if srv, ok := server.(unsafe.UnsafeServiceServer); ok { - return unsafe.RegisterUnsafeServiceHandlerServer(ctx, mux, srv) +func NewRegistration(ns string, dbRegister serviceregistry.DBRegister) *serviceregistry.Service[UnsafeService] { + return &serviceregistry.Service[UnsafeService]{ + ServiceOptions: serviceregistry.ServiceOptions[UnsafeService]{ + Namespace: ns, + DB: dbRegister, + ServiceDesc: &unsafe.UnsafeService_ServiceDesc, + RegisterFunc: func(srp serviceregistry.RegistrationParams) (*UnsafeService, serviceregistry.HandlerServer) { + unsafeSvc := &UnsafeService{dbClient: policydb.NewClient(srp.DBClient, srp.Logger), logger: srp.Logger} + return unsafeSvc, func(ctx context.Context, mux *runtime.ServeMux) error { + return unsafe.RegisterUnsafeServiceHandlerServer(ctx, mux, unsafeSvc) } - return fmt.Errorf("failed to assert server as unsafe.UnsafeServiceServer") - } + }, }, } } diff --git a/service/wellknownconfiguration/wellknown_configuration.go b/service/wellknownconfiguration/wellknown_configuration.go index dbfe045a1..bbfb34e3a 100644 --- a/service/wellknownconfiguration/wellknown_configuration.go +++ b/service/wellknownconfiguration/wellknown_configuration.go @@ -35,17 +35,17 @@ func RegisterConfiguration(namespace string, config any) error { return nil } -func NewRegistration() serviceregistry.Registration { - return serviceregistry.Registration{ - Namespace: "wellknown", - ServiceDesc: &wellknown.WellKnownService_ServiceDesc, - RegisterFunc: func(srp serviceregistry.RegistrationParams) (any, serviceregistry.HandlerServer) { - return &WellKnownService{logger: srp.Logger}, func(ctx context.Context, mux *runtime.ServeMux, server any) error { - if srv, ok := server.(wellknown.WellKnownServiceServer); ok { - return wellknown.RegisterWellKnownServiceHandlerServer(ctx, mux, srv) +func NewRegistration() *serviceregistry.Service[WellKnownService] { + return &serviceregistry.Service[WellKnownService]{ + ServiceOptions: serviceregistry.ServiceOptions[WellKnownService]{ + Namespace: "wellknown", + ServiceDesc: &wellknown.WellKnownService_ServiceDesc, + RegisterFunc: func(srp serviceregistry.RegistrationParams) (*WellKnownService, serviceregistry.HandlerServer) { + wk := &WellKnownService{logger: srp.Logger} + return wk, func(ctx context.Context, mux *runtime.ServeMux) error { + return wellknown.RegisterWellKnownServiceHandlerServer(ctx, mux, wk) } - return fmt.Errorf("failed to assert server as WellKnownServiceServer") - } + }, }, } }