Skip to content

Commit

Permalink
add support for filesystem implementation of token cache
Browse files Browse the repository at this point in the history
Signed-off-by: Nicholas Tate <[email protected]>
  • Loading branch information
nicktate committed May 24, 2024
1 parent 470621e commit 35b3b19
Show file tree
Hide file tree
Showing 6 changed files with 303 additions and 10 deletions.
13 changes: 8 additions & 5 deletions flytectl/cmd/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,28 @@ import (
"fmt"
"strings"

"github.com/flyteorg/flyte/flyteidl/clients/go/admin/cache"
"github.com/flyteorg/flyte/flytestdlib/config"

"github.com/flyteorg/flyte/flytectl/pkg/printer"
)

var (
defaultConfig = &Config{
Output: printer.OutputFormatTABLE.String(),
Output: printer.OutputFormatTABLE.String(),
TokenCacheType: cache.TokenCacheTypeKeyring,
}

section = config.MustRegisterSection("root", defaultConfig)
)

// Config hold configuration for flytectl flag
type Config struct {
Project string `json:"project" pflag:",Specifies the project to work on."`
Domain string `json:"domain" pflag:",Specifies the domain to work on."`
Output string `json:"output" pflag:",Specifies the output type."`
Interactive bool `json:"interactive" pflag:",Set this to trigger bubbletea interface."`
Project string `json:"project" pflag:",Specifies the project to work on."`
Domain string `json:"domain" pflag:",Specifies the domain to work on."`
Output string `json:"output" pflag:",Specifies the output type."`
Interactive bool `json:"interactive" pflag:",Set this to trigger bubbletea interface."`
TokenCacheType cache.TokenCacheType `json:"token_cache_type" pflag:",Specifices the token cache type to use for fetching / saving auth tokens."`
}

// OutputFormat will return output format
Expand Down
20 changes: 16 additions & 4 deletions flytectl/cmd/core/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/flyteorg/flyte/flytectl/cmd/config"
"github.com/flyteorg/flyte/flytectl/pkg/pkce"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/cache"

"github.com/spf13/cobra"
"github.com/spf13/pflag"
Expand Down Expand Up @@ -70,13 +71,24 @@ func generateCommandFunc(cmdEntry CommandEntry) func(cmd *cobra.Command, args []
return cmdEntry.CmdFunc(ctx, args, CommandContext{})
}

var tokenCache cache.TokenCache
svcUser := fmt.Sprintf("%s:%s", adminCfg.Endpoint.String(), pkce.KeyRingServiceUser)
switch config.GetConfig().TokenCacheType {
case cache.TokenCacheTypeFilesystem:
tokenCache = pkce.NewtokenCacheFilesystemProvider(svcUser)
case cache.TokenCacheTypeKeyring:
fallthrough
default:
tokenCache = pkce.TokenCacheKeyringProvider{
ServiceUser: svcUser,
ServiceName: pkce.KeyRingServiceName,
}
}

cmdCtx := NewCommandContextNoClient(cmd.OutOrStdout())
if !cmdEntry.DisableFlyteClient {
clientSet, err := admin.ClientSetBuilder().WithConfig(admin.GetConfig(ctx)).
WithTokenCache(pkce.TokenCacheKeyringProvider{
ServiceUser: fmt.Sprintf("%s:%s", adminCfg.Endpoint.String(), pkce.KeyRingServiceUser),
ServiceName: pkce.KeyRingServiceName,
}).Build(ctx)
WithTokenCache(tokenCache).Build(ctx)
if err != nil {
return err
}
Expand Down
2 changes: 2 additions & 0 deletions flytectl/cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"github.com/flyteorg/flyte/flytectl/cmd/version"
f "github.com/flyteorg/flyte/flytectl/pkg/filesystemutils"
"github.com/flyteorg/flyte/flytectl/pkg/printer"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/cache"
stdConfig "github.com/flyteorg/flyte/flytestdlib/config"
"github.com/flyteorg/flyte/flytestdlib/config/viper"

Expand Down Expand Up @@ -57,6 +58,7 @@ func newRootCmd() *cobra.Command {
rootCmd.PersistentFlags().StringVarP(&(config.GetConfig().Domain), "domain", "d", "", "Specifies the Flyte project's domain.")
rootCmd.PersistentFlags().StringVarP(&(config.GetConfig().Output), "output", "o", printer.OutputFormatTABLE.String(), fmt.Sprintf("Specifies the output type - supported formats %s. NOTE: dot, doturl are only supported for Workflow", printer.OutputFormats()))
rootCmd.PersistentFlags().BoolVarP(&(config.GetConfig().Interactive), "interactive", "i", false, "Set this flag to use an interactive CLI")
rootCmd.PersistentFlags().Var(&(config.GetConfig().TokenCacheType), "token-cache-type", fmt.Sprintf("Type of token cache to use (available options are %s)", cache.AllTokenCacheTypes))

rootCmd.AddCommand(get.CreateGetCommand())
compileCmd := compile.CreateCompileCommand()
Expand Down
137 changes: 137 additions & 0 deletions flytectl/pkg/pkce/token_cache_filesystem.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
package pkce

import (
"errors"
"fmt"
"os"
"path/filepath"

"golang.org/x/oauth2"

b64 "encoding/base64"
"encoding/json"

f "github.com/flyteorg/flyte/flytectl/pkg/filesystemutils"
)

// tokenCacheFilesystemProvider wraps the logic to save and retrieve tokens from the fs.
type tokenCacheFilesystemProvider struct {
ServiceUser string

// credentialsFile is the path to the file where the credentials are stored. This is
// typically $HOME/.flyte/credentials.json but embedded as a private field for tests.
credentialsFile string
}

func NewtokenCacheFilesystemProvider(serviceUser string) *tokenCacheFilesystemProvider {
return &tokenCacheFilesystemProvider{
ServiceUser: serviceUser,
credentialsFile: f.FilePathJoin(f.UserHomeDir(), ".flyte", "credentials.json"),
}
}

type credentials map[string]*oauth2.Token

func (c credentials) MarshalJSON() ([]byte, error) {
m := make(map[string]string)
for k, v := range c {
b, err := json.Marshal(v)
if err != nil {
return nil, err
}
m[k] = b64.StdEncoding.EncodeToString(b)
}
return json.Marshal(m)
}

func (c credentials) UnmarshalJSON(b []byte) error {
m := make(map[string]string)
if err := json.Unmarshal(b, &m); err != nil {
return err
}
for k, v := range m {
s, err := b64.StdEncoding.DecodeString(v)
if err != nil {
return err
}
tk := &oauth2.Token{}
if err = json.Unmarshal(s, tk); err != nil {
return err
}
c[k] = tk
}
return nil
}

func (t tokenCacheFilesystemProvider) SaveToken(token *oauth2.Token) error {
if token.AccessToken == "" {
return fmt.Errorf("cannot save empty token with expiration %v", token.Expiry)
}

dir := filepath.Dir(t.credentialsFile)
if err := os.MkdirAll(dir, 0700); err != nil {
return fmt.Errorf("creating base directory (%s) for credentials: %s", dir, err.Error())
}

creds, err := t.getExistingCredentials()
if err != nil {
return err
}
creds[t.ServiceUser] = token

tmp, err := os.CreateTemp("", "flytectl")
if err != nil {
return fmt.Errorf("creating tmp file for credentials update: %s", err.Error())
}
defer os.Remove(tmp.Name())

b, err := json.Marshal(creds)
if err != nil {
return fmt.Errorf("marshalling credentials: %s", err.Error())
}
if _, err := tmp.Write(b); err != nil {
return fmt.Errorf("writing updated credentials to tmp file: %s", err.Error())
}

if err = os.Rename(tmp.Name(), t.credentialsFile); err != nil {
return fmt.Errorf("updating credentials via tmp file rename: %s", err.Error())
}

return nil
}

func (t tokenCacheFilesystemProvider) GetToken() (*oauth2.Token, error) {
creds, err := t.getExistingCredentials()
if err != nil {
return nil, err
}

if token, ok := creds[t.ServiceUser]; ok {
return token, nil
}

return nil, errors.New("token does not exist")
}

func (t tokenCacheFilesystemProvider) getExistingCredentials() (credentials, error) {
dir := filepath.Dir(t.credentialsFile)
if err := os.MkdirAll(dir, 0700); err != nil {
return nil, fmt.Errorf("creating base directory (%s) for credentials: %s", dir, err.Error())
}

creds := credentials{}
if _, err := os.Stat(t.credentialsFile); errors.Is(err, os.ErrNotExist) {
return creds, nil
}

b, err := os.ReadFile(t.credentialsFile)
if err != nil {
return nil, fmt.Errorf("reading existing credentials: %s", err.Error())
}

if err = json.Unmarshal(b, &creds); err != nil {
return nil, fmt.Errorf("unmarshalling credentials: %s", err.Error())
}

return creds, nil
}
96 changes: 96 additions & 0 deletions flytectl/pkg/pkce/token_cache_filesystem_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package pkce

import (
"encoding/json"
"os"
"path/filepath"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
)

func TestSaveAndGetTokenFS(t *testing.T) {
setup := func(t *testing.T) tokenCacheFilesystemProvider {
t.Helper()
// Everything inside the directory is automatically cleaned up by the test runner.
dir := t.TempDir()
tokenCacheProvider := tokenCacheFilesystemProvider{
ServiceUser: "testServiceUser",
credentialsFile: filepath.Join(dir, "credentials.json"),
}
return tokenCacheProvider
}

t.Run("Valid Save/Get Token", func(t *testing.T) {
tokenCacheProvider := setup(t)

plan, err := os.ReadFile("testdata/token.json")
require.NoError(t, err)

var tokenData oauth2.Token
err = json.Unmarshal(plan, &tokenData)
require.NoError(t, err)

err = tokenCacheProvider.SaveToken(&tokenData)
require.NoError(t, err)

var savedToken *oauth2.Token
savedToken, err = tokenCacheProvider.GetToken()
require.NoError(t, err)

assert.NotNil(t, savedToken)
assert.Equal(t, tokenData.AccessToken, savedToken.AccessToken)
assert.Equal(t, tokenData.TokenType, savedToken.TokenType)
assert.Equal(t, tokenData.Expiry, savedToken.Expiry)
})

t.Run("Empty access token Save", func(t *testing.T) {
tokenCacheProvider := setup(t)

plan, err := os.ReadFile("testdata/empty_access_token.json")
require.NoError(t, err)

var tokenData oauth2.Token
err = json.Unmarshal(plan, &tokenData)
require.NoError(t, err)

err = tokenCacheProvider.SaveToken(&tokenData)
assert.Error(t, err)
})

t.Run("Different service name", func(t *testing.T) {
tokenCacheProvider := setup(t)

plan, err := os.ReadFile("testdata/token.json")
require.NoError(t, err)

var tokenData oauth2.Token
err = json.Unmarshal(plan, &tokenData)
require.NoError(t, err)

err = tokenCacheProvider.SaveToken(&tokenData)
require.NoError(t, err)

tokenCacheProvider2 := setup(t)

var savedToken *oauth2.Token
savedToken, err = tokenCacheProvider2.GetToken()
assert.Error(t, err)
assert.Nil(t, savedToken)

err = tokenCacheProvider2.SaveToken(&tokenData)
require.NoError(t, err)

// new token exists
savedToken, err = tokenCacheProvider2.GetToken()
require.NoError(t, err)
assert.NotNil(t, savedToken)

// token for different service name still exists
savedToken, err = tokenCacheProvider.GetToken()
require.NoError(t, err)
assert.NotNil(t, savedToken)
})
}
45 changes: 44 additions & 1 deletion flyteidl/clients/go/admin/cache/token_cache.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,52 @@
package cache

import "golang.org/x/oauth2"
import (
"fmt"
"slices"
"strings"

"golang.org/x/oauth2"
)

//go:generate mockery -all -case=underscore

// TokenCacheType defines the type of token cache implementation.
type TokenCacheType string

const (
// TokenCacheTypeKeyring represents the token cache implementation using the OS's keyring.
TokenCacheTypeKeyring TokenCacheType = "keyring"
// TokenCacheTypeInMemory represents the token cache implementation using an in-memory cache.
TokenCacheTypeInMemory = "inmemory"
// TokenCacheTypeFilesystem represents the token cache implementation using the local filesystem.
TokenCacheTypeFilesystem = "filesystem"
)

var AllTokenCacheTypes = []TokenCacheType{TokenCacheTypeKeyring, TokenCacheTypeInMemory, TokenCacheTypeFilesystem}

// String implements pflag.Value interface.
func (t *TokenCacheType) String() string {
if t == nil {
return ""
}
return string(*t)
}

// Set implements pflag.Value interface.
func (t *TokenCacheType) Set(value string) error {
if slices.Contains(AllTokenCacheTypes, TokenCacheType(strings.ToLower(value))) {
*t = TokenCacheType(value)
return nil
}

return fmt.Errorf("%s is an unrecognized token cache type (supported types %v)", value, AllTokenCacheTypes)
}

// Type implements pflag.Value interface.
func (t *TokenCacheType) Type() string {
return "token-cache-type"
}

// TokenCache defines the interface needed to cache and retrieve oauth tokens.
type TokenCache interface {
// SaveToken saves the token securely to cache.
Expand Down

0 comments on commit 35b3b19

Please sign in to comment.