diff --git a/clients/go/admin/auth_interceptor_test.go b/clients/go/admin/auth_interceptor_test.go index 029926902..d5e07a713 100644 --- a/clients/go/admin/auth_interceptor_test.go +++ b/clients/go/admin/auth_interceptor_test.go @@ -13,22 +13,18 @@ import ( "sync" "testing" - "github.com/flyteorg/flytestdlib/logger" - - "k8s.io/apimachinery/pkg/util/rand" - - mocks2 "github.com/flyteorg/flyteidl/clients/go/admin/mocks" - "github.com/stretchr/testify/mock" - - service2 "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" - "github.com/flyteorg/flytestdlib/config" - "github.com/stretchr/testify/assert" - - "github.com/flyteorg/flyteidl/clients/go/admin/cache/mocks" + "github.com/stretchr/testify/mock" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "k8s.io/apimachinery/pkg/util/rand" + + "github.com/flyteorg/flyteidl/clients/go/admin/cache/mocks" + adminMocks "github.com/flyteorg/flyteidl/clients/go/admin/mocks" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" + "github.com/flyteorg/flytestdlib/config" + "github.com/flyteorg/flytestdlib/logger" ) // authMetadataServer is a fake AuthMetadataServer that takes in an AuthMetadataServer implementation (usually one @@ -39,15 +35,15 @@ type authMetadataServer struct { port int grpcServer *grpc.Server netListener net.Listener - impl service2.AuthMetadataServiceServer + impl service.AuthMetadataServiceServer lck *sync.RWMutex } -func (s authMetadataServer) GetOAuth2Metadata(ctx context.Context, in *service2.OAuth2MetadataRequest) (*service2.OAuth2MetadataResponse, error) { +func (s authMetadataServer) GetOAuth2Metadata(ctx context.Context, in *service.OAuth2MetadataRequest) (*service.OAuth2MetadataResponse, error) { return s.impl.GetOAuth2Metadata(ctx, in) } -func (s authMetadataServer) GetPublicClientConfig(ctx context.Context, in *service2.PublicClientAuthConfigRequest) (*service2.PublicClientAuthConfigResponse, error) { +func (s authMetadataServer) GetPublicClientConfig(ctx context.Context, in *service.PublicClientAuthConfigRequest) (*service.PublicClientAuthConfigResponse, error) { return s.impl.GetPublicClientConfig(ctx, in) } @@ -84,7 +80,7 @@ func (s *authMetadataServer) Start(_ context.Context) error { } grpcS := grpc.NewServer() - service2.RegisterAuthMetadataServiceServer(grpcS, s) + service.RegisterAuthMetadataServiceServer(grpcS, s) go func() { _ = grpcS.Serve(lis) //assert.NoError(s.t, err) @@ -106,7 +102,7 @@ func (s *authMetadataServer) Close() { s.s.Close() } -func newAuthMetadataServer(t testing.TB, port int, impl service2.AuthMetadataServiceServer) *authMetadataServer { +func newAuthMetadataServer(t testing.TB, port int, impl service.AuthMetadataServiceServer) *authMetadataServer { return &authMetadataServer{ port: port, t: t, @@ -132,13 +128,13 @@ func Test_newAuthInterceptor(t *testing.T) { })) port := rand.IntnRange(10000, 60000) - m := &mocks2.AuthMetadataServiceServer{} - m.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(&service2.OAuth2MetadataResponse{ + m := &adminMocks.AuthMetadataServiceServer{} + m.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(&service.OAuth2MetadataResponse{ AuthorizationEndpoint: fmt.Sprintf("http://localhost:%d/oauth2/authorize", port), TokenEndpoint: fmt.Sprintf("http://localhost:%d/oauth2/token", port), JwksUri: fmt.Sprintf("http://localhost:%d/oauth2/jwks", port), }, nil) - m.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(&service2.PublicClientAuthConfigResponse{ + m.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(&service.PublicClientAuthConfigResponse{ Scopes: []string{"all"}, }, nil) s := newAuthMetadataServer(t, port, m) @@ -171,7 +167,7 @@ func Test_newAuthInterceptor(t *testing.T) { })) port := rand.IntnRange(10000, 60000) - m := &mocks2.AuthMetadataServiceServer{} + m := &adminMocks.AuthMetadataServiceServer{} s := newAuthMetadataServer(t, port, m) ctx := context.Background() assert.NoError(t, s.Start(ctx)) @@ -201,13 +197,13 @@ func Test_newAuthInterceptor(t *testing.T) { })) port := rand.IntnRange(10000, 60000) - m := &mocks2.AuthMetadataServiceServer{} - m.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(&service2.OAuth2MetadataResponse{ + m := &adminMocks.AuthMetadataServiceServer{} + m.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(&service.OAuth2MetadataResponse{ AuthorizationEndpoint: fmt.Sprintf("http://localhost:%d/oauth2/authorize", port), TokenEndpoint: fmt.Sprintf("http://localhost:%d/oauth2/token", port), JwksUri: fmt.Sprintf("http://localhost:%d/oauth2/jwks", port), }, nil) - m.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(&service2.PublicClientAuthConfigResponse{ + m.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(&service.PublicClientAuthConfigResponse{ Scopes: []string{"all"}, }, nil) @@ -237,8 +233,8 @@ func Test_newAuthInterceptor(t *testing.T) { func TestMaterializeCredentials(t *testing.T) { port := rand.IntnRange(10000, 60000) - t.Run("No public client config or oauth2 metadata endpoint lookup", func(t *testing.T) { - m := &mocks2.AuthMetadataServiceServer{} + t.Run("No oauth2 metadata endpoint or Public client config lookup", func(t *testing.T) { + m := &adminMocks.AuthMetadataServiceServer{} m.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(nil, errors.New("unexpected call to get oauth2 metadata")) m.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(nil, errors.New("unexpected call to get public client config")) s := newAuthMetadataServer(t, port, m) @@ -256,12 +252,13 @@ func TestMaterializeCredentials(t *testing.T) { AuthType: AuthTypeClientSecret, TokenURL: fmt.Sprintf("http://localhost:%d/api/v1/token", port), Scopes: []string{"all"}, + Audience: "http://localhost:30081", AuthorizationHeader: "authorization", }, &mocks.TokenCache{}, f) assert.NoError(t, err) }) t.Run("Failed to fetch client metadata", func(t *testing.T) { - m := &mocks2.AuthMetadataServiceServer{} + m := &adminMocks.AuthMetadataServiceServer{} m.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(nil, errors.New("unexpected call to get oauth2 metadata")) failedPublicClientConfigLookup := errors.New("expected err") m.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(nil, failedPublicClientConfigLookup) diff --git a/clients/go/admin/config.go b/clients/go/admin/config.go index b776451c4..dd0652606 100644 --- a/clients/go/admin/config.go +++ b/clients/go/admin/config.go @@ -53,6 +53,8 @@ type Config struct { ClientSecretLocation string `json:"clientSecretLocation" pflag:",File containing the client secret"` ClientSecretEnvVar string `json:"clientSecretEnvVar" pflag:",Environment variable containing the client secret"` Scopes []string `json:"scopes" pflag:",List of scopes to request"` + UseAudienceFromAdmin bool `json:"useAudienceFromAdmin" pflag:",Use Audience configured from admins public endpoint config."` + Audience string `json:"audience" pflag:",Audience to use when initiating OAuth2 authorization requests."` // There are two ways to get the token URL. If the authorization server url is provided, the client will try to use RFC 8414 to // try to get the token URL. Or it can be specified directly through TokenURL config. diff --git a/clients/go/admin/config_flags.go b/clients/go/admin/config_flags.go index 6d65d30f8..53a6a4421 100755 --- a/clients/go/admin/config_flags.go +++ b/clients/go/admin/config_flags.go @@ -64,6 +64,8 @@ func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags.String(fmt.Sprintf("%v%v", prefix, "clientSecretLocation"), defaultConfig.ClientSecretLocation, "File containing the client secret") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "clientSecretEnvVar"), defaultConfig.ClientSecretEnvVar, "Environment variable containing the client secret") cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "scopes"), defaultConfig.Scopes, "List of scopes to request") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "useAudienceFromAdmin"), defaultConfig.UseAudienceFromAdmin, "Use Audience configured from admins public endpoint config.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "audience"), defaultConfig.Audience, "Audience to use when initiating OAuth2 authorization requests.") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "authorizationServerUrl"), defaultConfig.DeprecatedAuthorizationServerURL, "This is the URL to your IdP's authorization server. It'll default to Endpoint") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "tokenUrl"), defaultConfig.TokenURL, "OPTIONAL: Your IdP's token endpoint. It'll be discovered from flyte admin's OAuth Metadata endpoint if not provided.") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "authorizationHeader"), defaultConfig.AuthorizationHeader, "Custom metadata header to pass JWT") diff --git a/clients/go/admin/config_flags_test.go b/clients/go/admin/config_flags_test.go index a44948d87..bdcec55f6 100755 --- a/clients/go/admin/config_flags_test.go +++ b/clients/go/admin/config_flags_test.go @@ -295,6 +295,34 @@ func TestConfig_SetFlags(t *testing.T) { } }) }) + t.Run("Test_useAudienceFromAdmin", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("useAudienceFromAdmin", testValue) + if vBool, err := cmdFlags.GetBool("useAudienceFromAdmin"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vBool), &actual.UseAudienceFromAdmin) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_audience", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("audience", testValue) + if vString, err := cmdFlags.GetString("audience"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.Audience) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) t.Run("Test_authorizationServerUrl", func(t *testing.T) { t.Run("Override", func(t *testing.T) { diff --git a/clients/go/admin/token_source_provider.go b/clients/go/admin/token_source_provider.go index 78a1952ab..c2a520d70 100644 --- a/clients/go/admin/token_source_provider.go +++ b/clients/go/admin/token_source_provider.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io/ioutil" + "net/url" "os" "strings" "sync" @@ -22,6 +23,10 @@ import ( "github.com/flyteorg/flytestdlib/logger" ) +const ( + audienceKey = "audience" +) + // TokenSourceProvider defines the interface needed to provide a TokenSource that is used to // create a client with authentication enabled. type TokenSourceProvider interface { @@ -46,15 +51,24 @@ func NewTokenSourceProvider(ctx context.Context, cfg *Config, tokenCache cache.T } scopes := cfg.Scopes - if len(scopes) == 0 { - clientMetadata, err := authClient.GetPublicClientConfig(ctx, &service.PublicClientAuthConfigRequest{}) + audienceValue := cfg.Audience + + if len(scopes) == 0 || cfg.UseAudienceFromAdmin { + publicClientConfig, err := authClient.GetPublicClientConfig(ctx, &service.PublicClientAuthConfigRequest{}) if err != nil { return nil, fmt.Errorf("failed to fetch client metadata. Error: %v", err) } - scopes = clientMetadata.Scopes + // Update scopes from publicClientConfig + if len(scopes) == 0 { + scopes = publicClientConfig.Scopes + } + // Update audience from publicClientConfig + if cfg.UseAudienceFromAdmin { + audienceValue = publicClientConfig.Audience + } } - tokenProvider, err = NewClientCredentialsTokenSourceProvider(ctx, cfg, scopes, tokenURL) + tokenProvider, err = NewClientCredentialsTokenSourceProvider(ctx, cfg, scopes, tokenURL, audienceValue) if err != nil { return nil, err } @@ -152,7 +166,7 @@ type ClientCredentialsTokenSourceProvider struct { TokenRefreshWindow time.Duration } -func NewClientCredentialsTokenSourceProvider(ctx context.Context, cfg *Config, scopes []string, tokenURL string) (TokenSourceProvider, error) { +func NewClientCredentialsTokenSourceProvider(ctx context.Context, cfg *Config, scopes []string, tokenURL string, audience string) (TokenSourceProvider, error) { var secret string if len(cfg.ClientSecretEnvVar) > 0 { secret = os.Getenv(cfg.ClientSecretEnvVar) @@ -164,13 +178,19 @@ func NewClientCredentialsTokenSourceProvider(ctx context.Context, cfg *Config, s } secret = string(secretBytes) } + endpointParams := url.Values{} + if len(audience) > 0 { + endpointParams = url.Values{audienceKey: {audience}} + } secret = strings.TrimSpace(secret) return ClientCredentialsTokenSourceProvider{ ccConfig: clientcredentials.Config{ - ClientID: cfg.ClientID, - ClientSecret: secret, - TokenURL: tokenURL, - Scopes: scopes}, + ClientID: cfg.ClientID, + ClientSecret: secret, + TokenURL: tokenURL, + Scopes: scopes, + EndpointParams: endpointParams, + }, TokenRefreshWindow: cfg.TokenRefreshWindow.Duration}, nil } diff --git a/clients/go/admin/token_source_test.go b/clients/go/admin/token_source_test.go index 0e247bfe8..9256e5e88 100644 --- a/clients/go/admin/token_source_test.go +++ b/clients/go/admin/token_source_test.go @@ -2,10 +2,16 @@ package admin import ( "context" + "net/url" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "golang.org/x/oauth2" + + tokenCacheMocks "github.com/flyteorg/flyteidl/clients/go/admin/cache/mocks" + adminMocks "github.com/flyteorg/flyteidl/clients/go/admin/mocks" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" ) type DummyTestTokenSource struct { @@ -25,3 +31,67 @@ func TestNewTokenSource(t *testing.T) { assert.NoError(t, err) assert.Equal(t, "Bearer abc", metadata["test"]) } + +func TestNewTokenSourceProvider(t *testing.T) { + ctx := context.Background() + tests := []struct { + name string + audienceCfg string + scopesCfg []string + useAudienceFromAdmin bool + clientConfigResponse service.PublicClientAuthConfigResponse + expectedAudience string + expectedScopes []string + expectedCallsPubEndpoint int + }{ + { + name: "audience from client config", + audienceCfg: "clientConfiguredAud", + scopesCfg: []string{"all"}, + clientConfigResponse: service.PublicClientAuthConfigResponse{}, + expectedAudience: "clientConfiguredAud", + expectedScopes: []string{"all"}, + expectedCallsPubEndpoint: 0, + }, + { + name: "audience from public client response", + audienceCfg: "clientConfiguredAud", + useAudienceFromAdmin: true, + scopesCfg: []string{"all"}, + clientConfigResponse: service.PublicClientAuthConfigResponse{Audience: "AdminConfiguredAud", Scopes: []string{}}, + expectedAudience: "AdminConfiguredAud", + expectedScopes: []string{"all"}, + expectedCallsPubEndpoint: 1, + }, + + { + name: "audience from client with useAudience from admin false", + audienceCfg: "clientConfiguredAud", + useAudienceFromAdmin: false, + scopesCfg: []string{"all"}, + clientConfigResponse: service.PublicClientAuthConfigResponse{Audience: "AdminConfiguredAud", Scopes: []string{}}, + expectedAudience: "clientConfiguredAud", + expectedScopes: []string{"all"}, + expectedCallsPubEndpoint: 0, + }, + } + for _, test := range tests { + cfg := GetConfig(ctx) + tokenCache := &tokenCacheMocks.TokenCache{} + metadataClient := &adminMocks.AuthMetadataServiceClient{} + metadataClient.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(&service.OAuth2MetadataResponse{}, nil) + metadataClient.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(&test.clientConfigResponse, nil) + cfg.AuthType = AuthTypeClientSecret + cfg.Audience = test.audienceCfg + cfg.Scopes = test.scopesCfg + cfg.UseAudienceFromAdmin = test.useAudienceFromAdmin + flyteTokenSource, err := NewTokenSourceProvider(ctx, cfg, tokenCache, metadataClient) + assert.True(t, metadataClient.AssertNumberOfCalls(t, "GetPublicClientConfig", test.expectedCallsPubEndpoint)) + assert.NoError(t, err) + assert.NotNil(t, flyteTokenSource) + clientCredSourceProvider, ok := flyteTokenSource.(ClientCredentialsTokenSourceProvider) + assert.True(t, ok) + assert.Equal(t, test.expectedScopes, clientCredSourceProvider.ccConfig.Scopes) + assert.Equal(t, url.Values{audienceKey: {test.expectedAudience}}, clientCredSourceProvider.ccConfig.EndpointParams) + } +}