diff --git a/service/go.mod b/service/go.mod index 8f42d83df..00ad68f06 100644 --- a/service/go.mod +++ b/service/go.mod @@ -151,7 +151,7 @@ require ( github.com/yusufpapurcu/wmi v1.2.3 // indirect go.uber.org/multierr v1.11.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect + gopkg.in/yaml.v3 v3.0.1 sigs.k8s.io/yaml v1.4.0 // indirect ) diff --git a/service/pkg/server/start.go b/service/pkg/server/start.go index 6624d3117..92c7b49ea 100644 --- a/service/pkg/server/start.go +++ b/service/pkg/server/start.go @@ -157,11 +157,15 @@ func Start(f ...StartOptions) error { oidcconfig *auth.OIDCConfiguration ) - // If the mode is not all or entityresolution, we need to have a valid SDK config + // If the mode is not all, does not include both core and entityresolution, or is not entityresolution on its own, we need to have a valid SDK config // entityresolution does not connect to other services and can run on its own - if !slices.Contains(cfg.Mode, "all") && !slices.Contains(cfg.Mode, "entityresolution") && cfg.SDKConfig == (config.SDKConfig{}) { - logger.Error("mode is not all or entityresolution, but no sdk config provided") - return errors.New("mode is not all or entityresolution, but no sdk config provided") + // core only connects to entityresolution + if !(slices.Contains(cfg.Mode, "all") || // no config required for all mode + (slices.Contains(cfg.Mode, "core") && slices.Contains(cfg.Mode, "entityresolution")) || // or core and entityresolution modes togethor + (slices.Contains(cfg.Mode, "entityresolution") && len(cfg.Mode) == 1)) && // or entityresolution on its own + cfg.SDKConfig == (config.SDKConfig{}) { + logger.Error("mode is not all, entityresolution, or a combination of core and entityresolution, but no sdk config provided") + return errors.New("mode is not all, entityresolution, or a combination of core and entityresolution, but no sdk config provided") } // If client credentials are provided, use them @@ -186,7 +190,7 @@ func Start(f ...StartOptions) error { sdkOptions = append(sdkOptions, sdk.WithCustomCoreConnection(otdf.ConnectRPCInProcess.Conn())) // handle ERS connection for core mode - if slices.Contains(cfg.Mode, "core") { + if slices.Contains(cfg.Mode, "core") && !slices.Contains(cfg.Mode, "entityresolution") { logger.Info("core mode") if cfg.SDKConfig.EntityResolutionConnection.Endpoint == "" { diff --git a/service/pkg/server/start_test.go b/service/pkg/server/start_test.go index db66b50b2..bd54906df 100644 --- a/service/pkg/server/start_test.go +++ b/service/pkg/server/start_test.go @@ -2,10 +2,13 @@ package server import ( "context" + "fmt" "io" "log/slog" "net/http" "net/http/httptest" + "os" + "strings" "testing" "time" @@ -18,6 +21,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + "gopkg.in/yaml.v3" ) type ( @@ -32,9 +36,8 @@ func (t TestService) TestHandler(w http.ResponseWriter, _ *http.Request, _ map[s } } -func mockOpenTDFServer() (*server.OpenTDFServer, error) { +func mockKeycloakServer() *httptest.Server { discoveryURL := "not set yet" - discoveryEndpoint := httptest.NewServer( http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { var resp string @@ -62,6 +65,11 @@ func mockOpenTDFServer() (*server.OpenTDFServer, error) { discoveryURL = discoveryEndpoint.URL + return discoveryEndpoint +} + +func mockOpenTDFServer() (*server.OpenTDFServer, error) { + discoveryEndpoint := mockKeycloakServer() // Create new opentdf server return server.NewOpenTDFServer(server.Config{ WellKnownConfigRegister: func(_ string, _ any) error { @@ -82,6 +90,70 @@ func mockOpenTDFServer() (*server.OpenTDFServer, error) { ) } +func updateNestedKey(data map[string]interface{}, path []string, value interface{}) error { + if len(path) == 0 { + return fmt.Errorf("path cannot be empty") + } + + current := data + for i, key := range path[:len(path)-1] { + if next, ok := current[key]; ok { + if nextMap, ok2 := next.(map[string]interface{}); ok2 { + current = nextMap + } else { + return fmt.Errorf("key %s at path level %d is not a map", key, i) + } + } else { + // If the key doesn't exist, initialize a new map + newMap := make(map[string]interface{}) + current[key] = newMap + current = newMap + } + } + + // Set the value at the final key + current[path[len(path)-1]] = value + return nil +} + +func createTempYAMLFileWithNestedChanges(changes map[string]interface{}, originalFilePath string, newFileName string) (string, error) { + // Load the original YAML file + data, err := os.ReadFile(originalFilePath) + if err != nil { + return "", err + } + + var yamlData map[string]interface{} + if err := yaml.Unmarshal(data, &yamlData); err != nil { + return "", err + } + + // Apply all changes + for keyPath, value := range changes { + path := strings.Split(keyPath, ".") // Convert dot notation to slice + if err := updateNestedKey(yamlData, path, value); err != nil { + return "", err + } + } + + // Create a temporary file + tempFile, err := os.CreateTemp("testdata", newFileName) + if err != nil { + return "", err + } + defer tempFile.Close() + + // Write the modified YAML to the temp file + encoder := yaml.NewEncoder(tempFile) + defer encoder.Close() + + if err := encoder.Encode(&yamlData); err != nil { + return "", err + } + + return tempFile.Name(), nil +} + type StartTestSuite struct { suite.Suite } @@ -142,3 +214,109 @@ func (suite *StartTestSuite) Test_Start_When_Extra_Service_Registered_Expect_Res require.NoError(t, err) assert.Equal(t, "hello from test service!", string(respBody)) } + +func (suite *StartTestSuite) Test_Start_Mode_Config_Errors() { + t := suite.T() + discoveryEndpoint := mockKeycloakServer() + originalFilePath := "testdata/all-no-config.yaml" + testCases := []struct { + name string + changes map[string]interface{} + newConfigFile string + expErrorContains string + }{ + {"core without sdk_config", + map[string]interface{}{ + "mode": "core", "server.auth.issuer": discoveryEndpoint.URL}, + "err-core-no-config-*.yaml", "no sdk config provided"}, + {"kas without sdk_config", + map[string]interface{}{ + "mode": "kas", "server.auth.issuer": discoveryEndpoint.URL}, + "err-kas-no-config-*.yaml", "no sdk config provided"}, + {"core with sdk_config without ers endpoint", + map[string]interface{}{ + "mode": "core", "server.auth.issuer": discoveryEndpoint.URL, + "sdk_config.client_id": "opentdf", "sdk_config.client_secret": "opentdf"}, + "err-core-w-config-no-ers-*.yaml", "entityresolution endpoint must be provided in core mode"}, + } + var tempFiles []string + defer func() { + // Cleanup all created temp files + for _, tempFile := range tempFiles { + if err := os.Remove(tempFile); err != nil { + t.Errorf("Failed to remove temp file %s: %v", tempFile, err) + } + } + }() + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tempFilePath, err := createTempYAMLFileWithNestedChanges(tc.changes, originalFilePath, tc.newConfigFile) + if err != nil { + t.Fatalf("Failed to create temp YAML file: %v", err) + } + tempFiles = append(tempFiles, tempFilePath) + + err = Start( + WithConfigFile(tempFilePath), + ) + require.Error(t, err) + require.ErrorContains(t, err, tc.expErrorContains) + }) + } +} + +func (suite *StartTestSuite) Test_Start_Mode_Config_Success() { + t := suite.T() + discoveryEndpoint := mockKeycloakServer() + // require.NoError(t, err) + originalFilePath := "testdata/all-no-config.yaml" + testCases := []struct { + name string + changes map[string]interface{} + newConfigFile string + }{ + {"all without sdk_config", + map[string]interface{}{ + "server.auth.issuer": discoveryEndpoint.URL}, + "all-no-config-*.yaml"}, + {"core,entityresolution without sdk_config", + map[string]interface{}{ + "mode": "core,entityresolution", "server.auth.issuer": discoveryEndpoint.URL}, + "all-no-config-*.yaml"}, + {"core,entityresolution,kas without sdk_config", + map[string]interface{}{ + "mode": "core,entityresolution,kas", "server.auth.issuer": discoveryEndpoint.URL}, + "all-no-config-*.yaml"}, + {"core with correct sdk_config", + map[string]interface{}{ + "mode": "core", "server.auth.issuer": discoveryEndpoint.URL, + "sdk_config.client_id": "opentdf", "sdk_config.client_secret": "opentdf", + "sdk_config.entityresolution.endpoint": "http://localhost:8181", "sdk_config.entityresolution.plaintext": "true"}, + "core-w-config-correct-*.yaml"}, + } + var tempFiles []string + defer func() { + // Cleanup all created temp files + for _, tempFile := range tempFiles { + if err := os.Remove(tempFile); err != nil { + t.Errorf("Failed to remove temp file %s: %v", tempFile, err) + } + } + }() + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tempFilePath, err := createTempYAMLFileWithNestedChanges(tc.changes, originalFilePath, tc.newConfigFile) + if err != nil { + t.Fatalf("Failed to create temp YAML file: %v", err) + } + tempFiles = append(tempFiles, tempFilePath) + + err = Start( + WithConfigFile(tempFilePath), + ) + // require that it got past the service config and mode setup + // expected error when trying to establish db connection + require.ErrorContains(t, err, "failed to connect to database") + }) + } +} diff --git a/service/pkg/server/testdata/all-no-config.yaml b/service/pkg/server/testdata/all-no-config.yaml new file mode 100644 index 000000000..ff4e1f299 --- /dev/null +++ b/service/pkg/server/testdata/all-no-config.yaml @@ -0,0 +1,100 @@ + +mode: all +logger: + level: debug + type: text + output: stdout +services: + kas: + keyring: + - kid: e1 + alg: ec:secp256r1 + - kid: e1 + alg: ec:secp256r1 + legacy: true + - kid: r1 + alg: rsa:2048 + - kid: r1 + alg: rsa:2048 + legacy: true + entityresolution: + log_level: info + url: http://localhost:8888/auth + clientid: 'tdf-entity-resolution' + clientsecret: 'secret' + realm: 'opentdf' + legacykeycloak: true + inferid: + from: + email: true + username: true +server: + tls: + enabled: false + cert: ./keys/platform.crt + key: ./keys/platform-key.pem + auth: + enabled: true + enforceDPoP: false + public_client_id: 'opentdf-public' + audience: 'http://localhost:8080' + issuer: http://localhost:8888/auth/realms/opentdf + policy: + ## Dot notation is used to access nested claims (i.e. realm_access.roles) + # Claim that represents the user (i.e. email) + username_claim: # preferred_username + # That claim to access groups (i.e. realm_access.roles) + groups_claim: # realm_access.roles + ## Extends the builtin policy + extension: | + g, opentdf-admin, role:admin + g, opentdf-standard, role:standard + ## Custom policy that overrides builtin policy (see examples https://github.com/casbin/casbin/tree/master/examples) + csv: #| + # p, role:admin, *, *, allow + ## Custom model (see https://casbin.org/docs/syntax-for-models/) + model: #| + # [request_definition] + # r = sub, res, act, obj + # + # [policy_definition] + # p = sub, res, act, obj, eft + # + # [role_definition] + # g = _, _ + # + # [policy_effect] + # e = some(where (p.eft == allow)) && !some(where (p.eft == deny)) + # + # [matchers] + # m = g(r.sub, p.sub) && globOrRegexMatch(r.res, p.res) && globOrRegexMatch(r.act, p.act) && globOrRegexMatch(r.obj, p.obj) + cors: + enabled: false + # "*" to allow any origin or a specific domain like "https://yourdomain.com" + allowedorigins: + - '*' + # List of methods. Examples: "GET,POST,PUT" + allowedmethods: + - GET + - POST + - PATCH + - PUT + - DELETE + - OPTIONS + # List of headers that are allowed in a request + allowedheaders: + - ACCEPT + - Authorization + - Content-Type + - X-CSRF-Token + - X-Request-ID + # List of response headers that browsers are allowed to access + exposedheaders: + - Link + # Sets whether credentials are included in the CORS request + allowcredentials: true + # Sets the maximum age (in seconds) of a specific CORS preflight request + maxage: 3600 + grpc: + reflectionEnabled: true # Default is false + port: 8080