Skip to content

Commit

Permalink
fix(core): Handle multiple modes including entityresolution mode (#1816)
Browse files Browse the repository at this point in the history
### Proposed Changes

* handle sdk_config error throwing when mulltiple modes are present
* need to sdk config when running in kas or core without ers

### Checklist

- [ ] I have added or updated unit tests
- [ ] I have added or updated integration tests (if appropriate)
- [ ] I have added or updated documentation

### Testing Instructions
  • Loading branch information
elizabethhealy authored Dec 6, 2024
1 parent 4d47475 commit 32d6938
Show file tree
Hide file tree
Showing 4 changed files with 290 additions and 8 deletions.
2 changes: 1 addition & 1 deletion service/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ require (
github.com/yusufpapurcu/wmi v1.2.3 // indirect
go.uber.org/multierr v1.11.0 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
gopkg.in/yaml.v3 v3.0.1
sigs.k8s.io/yaml v1.4.0 // indirect
)

Expand Down
14 changes: 9 additions & 5 deletions service/pkg/server/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,11 +157,15 @@ func Start(f ...StartOptions) error {
oidcconfig *auth.OIDCConfiguration
)

// If the mode is not all or entityresolution, we need to have a valid SDK config
// If the mode is not all, does not include both core and entityresolution, or is not entityresolution on its own, we need to have a valid SDK config
// entityresolution does not connect to other services and can run on its own
if !slices.Contains(cfg.Mode, "all") && !slices.Contains(cfg.Mode, "entityresolution") && cfg.SDKConfig == (config.SDKConfig{}) {
logger.Error("mode is not all or entityresolution, but no sdk config provided")
return errors.New("mode is not all or entityresolution, but no sdk config provided")
// core only connects to entityresolution
if !(slices.Contains(cfg.Mode, "all") || // no config required for all mode
(slices.Contains(cfg.Mode, "core") && slices.Contains(cfg.Mode, "entityresolution")) || // or core and entityresolution modes togethor
(slices.Contains(cfg.Mode, "entityresolution") && len(cfg.Mode) == 1)) && // or entityresolution on its own
cfg.SDKConfig == (config.SDKConfig{}) {
logger.Error("mode is not all, entityresolution, or a combination of core and entityresolution, but no sdk config provided")
return errors.New("mode is not all, entityresolution, or a combination of core and entityresolution, but no sdk config provided")
}

// If client credentials are provided, use them
Expand All @@ -186,7 +190,7 @@ func Start(f ...StartOptions) error {
sdkOptions = append(sdkOptions, sdk.WithCustomCoreConnection(otdf.ConnectRPCInProcess.Conn()))

// handle ERS connection for core mode
if slices.Contains(cfg.Mode, "core") {
if slices.Contains(cfg.Mode, "core") && !slices.Contains(cfg.Mode, "entityresolution") {
logger.Info("core mode")

if cfg.SDKConfig.EntityResolutionConnection.Endpoint == "" {
Expand Down
182 changes: 180 additions & 2 deletions service/pkg/server/start_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@ package server

import (
"context"
"fmt"
"io"
"log/slog"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"time"

Expand All @@ -18,6 +21,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"gopkg.in/yaml.v3"
)

type (
Expand All @@ -32,9 +36,8 @@ func (t TestService) TestHandler(w http.ResponseWriter, _ *http.Request, _ map[s
}
}

func mockOpenTDFServer() (*server.OpenTDFServer, error) {
func mockKeycloakServer() *httptest.Server {
discoveryURL := "not set yet"

discoveryEndpoint := httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
var resp string
Expand Down Expand Up @@ -62,6 +65,11 @@ func mockOpenTDFServer() (*server.OpenTDFServer, error) {

discoveryURL = discoveryEndpoint.URL

return discoveryEndpoint
}

func mockOpenTDFServer() (*server.OpenTDFServer, error) {
discoveryEndpoint := mockKeycloakServer()
// Create new opentdf server
return server.NewOpenTDFServer(server.Config{
WellKnownConfigRegister: func(_ string, _ any) error {
Expand All @@ -82,6 +90,70 @@ func mockOpenTDFServer() (*server.OpenTDFServer, error) {
)
}

func updateNestedKey(data map[string]interface{}, path []string, value interface{}) error {
if len(path) == 0 {
return fmt.Errorf("path cannot be empty")
}

current := data
for i, key := range path[:len(path)-1] {
if next, ok := current[key]; ok {
if nextMap, ok2 := next.(map[string]interface{}); ok2 {
current = nextMap
} else {
return fmt.Errorf("key %s at path level %d is not a map", key, i)
}
} else {
// If the key doesn't exist, initialize a new map
newMap := make(map[string]interface{})
current[key] = newMap
current = newMap
}
}

// Set the value at the final key
current[path[len(path)-1]] = value
return nil
}

func createTempYAMLFileWithNestedChanges(changes map[string]interface{}, originalFilePath string, newFileName string) (string, error) {
// Load the original YAML file
data, err := os.ReadFile(originalFilePath)
if err != nil {
return "", err
}

var yamlData map[string]interface{}
if err := yaml.Unmarshal(data, &yamlData); err != nil {
return "", err
}

// Apply all changes
for keyPath, value := range changes {
path := strings.Split(keyPath, ".") // Convert dot notation to slice
if err := updateNestedKey(yamlData, path, value); err != nil {
return "", err
}
}

// Create a temporary file
tempFile, err := os.CreateTemp("testdata", newFileName)
if err != nil {
return "", err
}
defer tempFile.Close()

// Write the modified YAML to the temp file
encoder := yaml.NewEncoder(tempFile)
defer encoder.Close()

if err := encoder.Encode(&yamlData); err != nil {
return "", err
}

return tempFile.Name(), nil
}

type StartTestSuite struct {
suite.Suite
}
Expand Down Expand Up @@ -142,3 +214,109 @@ func (suite *StartTestSuite) Test_Start_When_Extra_Service_Registered_Expect_Res
require.NoError(t, err)
assert.Equal(t, "hello from test service!", string(respBody))
}

func (suite *StartTestSuite) Test_Start_Mode_Config_Errors() {
t := suite.T()
discoveryEndpoint := mockKeycloakServer()
originalFilePath := "testdata/all-no-config.yaml"
testCases := []struct {
name string
changes map[string]interface{}
newConfigFile string
expErrorContains string
}{
{"core without sdk_config",
map[string]interface{}{
"mode": "core", "server.auth.issuer": discoveryEndpoint.URL},
"err-core-no-config-*.yaml", "no sdk config provided"},
{"kas without sdk_config",
map[string]interface{}{
"mode": "kas", "server.auth.issuer": discoveryEndpoint.URL},
"err-kas-no-config-*.yaml", "no sdk config provided"},
{"core with sdk_config without ers endpoint",
map[string]interface{}{
"mode": "core", "server.auth.issuer": discoveryEndpoint.URL,
"sdk_config.client_id": "opentdf", "sdk_config.client_secret": "opentdf"},
"err-core-w-config-no-ers-*.yaml", "entityresolution endpoint must be provided in core mode"},
}
var tempFiles []string
defer func() {
// Cleanup all created temp files
for _, tempFile := range tempFiles {
if err := os.Remove(tempFile); err != nil {
t.Errorf("Failed to remove temp file %s: %v", tempFile, err)
}
}
}()
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tempFilePath, err := createTempYAMLFileWithNestedChanges(tc.changes, originalFilePath, tc.newConfigFile)
if err != nil {
t.Fatalf("Failed to create temp YAML file: %v", err)
}
tempFiles = append(tempFiles, tempFilePath)

err = Start(
WithConfigFile(tempFilePath),
)
require.Error(t, err)
require.ErrorContains(t, err, tc.expErrorContains)
})
}
}

func (suite *StartTestSuite) Test_Start_Mode_Config_Success() {
t := suite.T()
discoveryEndpoint := mockKeycloakServer()
// require.NoError(t, err)
originalFilePath := "testdata/all-no-config.yaml"
testCases := []struct {
name string
changes map[string]interface{}
newConfigFile string
}{
{"all without sdk_config",
map[string]interface{}{
"server.auth.issuer": discoveryEndpoint.URL},
"all-no-config-*.yaml"},
{"core,entityresolution without sdk_config",
map[string]interface{}{
"mode": "core,entityresolution", "server.auth.issuer": discoveryEndpoint.URL},
"all-no-config-*.yaml"},
{"core,entityresolution,kas without sdk_config",
map[string]interface{}{
"mode": "core,entityresolution,kas", "server.auth.issuer": discoveryEndpoint.URL},
"all-no-config-*.yaml"},
{"core with correct sdk_config",
map[string]interface{}{
"mode": "core", "server.auth.issuer": discoveryEndpoint.URL,
"sdk_config.client_id": "opentdf", "sdk_config.client_secret": "opentdf",
"sdk_config.entityresolution.endpoint": "http://localhost:8181", "sdk_config.entityresolution.plaintext": "true"},
"core-w-config-correct-*.yaml"},
}
var tempFiles []string
defer func() {
// Cleanup all created temp files
for _, tempFile := range tempFiles {
if err := os.Remove(tempFile); err != nil {
t.Errorf("Failed to remove temp file %s: %v", tempFile, err)
}
}
}()
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tempFilePath, err := createTempYAMLFileWithNestedChanges(tc.changes, originalFilePath, tc.newConfigFile)
if err != nil {
t.Fatalf("Failed to create temp YAML file: %v", err)
}
tempFiles = append(tempFiles, tempFilePath)

err = Start(
WithConfigFile(tempFilePath),
)
// require that it got past the service config and mode setup
// expected error when trying to establish db connection
require.ErrorContains(t, err, "failed to connect to database")
})
}
}
100 changes: 100 additions & 0 deletions service/pkg/server/testdata/all-no-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@

mode: all
logger:
level: debug
type: text
output: stdout
services:
kas:
keyring:
- kid: e1
alg: ec:secp256r1
- kid: e1
alg: ec:secp256r1
legacy: true
- kid: r1
alg: rsa:2048
- kid: r1
alg: rsa:2048
legacy: true
entityresolution:
log_level: info
url: http://localhost:8888/auth
clientid: 'tdf-entity-resolution'
clientsecret: 'secret'
realm: 'opentdf'
legacykeycloak: true
inferid:
from:
email: true
username: true
server:
tls:
enabled: false
cert: ./keys/platform.crt
key: ./keys/platform-key.pem
auth:
enabled: true
enforceDPoP: false
public_client_id: 'opentdf-public'
audience: 'http://localhost:8080'
issuer: http://localhost:8888/auth/realms/opentdf
policy:
## Dot notation is used to access nested claims (i.e. realm_access.roles)
# Claim that represents the user (i.e. email)
username_claim: # preferred_username
# That claim to access groups (i.e. realm_access.roles)
groups_claim: # realm_access.roles
## Extends the builtin policy
extension: |
g, opentdf-admin, role:admin
g, opentdf-standard, role:standard
## Custom policy that overrides builtin policy (see examples https://github.com/casbin/casbin/tree/master/examples)
csv: #|
# p, role:admin, *, *, allow
## Custom model (see https://casbin.org/docs/syntax-for-models/)
model: #|
# [request_definition]
# r = sub, res, act, obj
#
# [policy_definition]
# p = sub, res, act, obj, eft
#
# [role_definition]
# g = _, _
#
# [policy_effect]
# e = some(where (p.eft == allow)) && !some(where (p.eft == deny))
#
# [matchers]
# m = g(r.sub, p.sub) && globOrRegexMatch(r.res, p.res) && globOrRegexMatch(r.act, p.act) && globOrRegexMatch(r.obj, p.obj)
cors:
enabled: false
# "*" to allow any origin or a specific domain like "https://yourdomain.com"
allowedorigins:
- '*'
# List of methods. Examples: "GET,POST,PUT"
allowedmethods:
- GET
- POST
- PATCH
- PUT
- DELETE
- OPTIONS
# List of headers that are allowed in a request
allowedheaders:
- ACCEPT
- Authorization
- Content-Type
- X-CSRF-Token
- X-Request-ID
# List of response headers that browsers are allowed to access
exposedheaders:
- Link
# Sets whether credentials are included in the CORS request
allowcredentials: true
# Sets the maximum age (in seconds) of a specific CORS preflight request
maxage: 3600
grpc:
reflectionEnabled: true # Default is false
port: 8080

0 comments on commit 32d6938

Please sign in to comment.