-
Notifications
You must be signed in to change notification settings - Fork 661
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add support for filesystem implementation of token cache
Signed-off-by: Nicholas Tate <[email protected]>
- Loading branch information
Showing
6 changed files
with
303 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
package pkce | ||
|
||
import ( | ||
"errors" | ||
"fmt" | ||
"os" | ||
"path/filepath" | ||
|
||
"golang.org/x/oauth2" | ||
|
||
b64 "encoding/base64" | ||
"encoding/json" | ||
|
||
f "github.com/flyteorg/flyte/flytectl/pkg/filesystemutils" | ||
) | ||
|
||
// tokenCacheFilesystemProvider wraps the logic to save and retrieve tokens from the fs. | ||
type tokenCacheFilesystemProvider struct { | ||
ServiceUser string | ||
|
||
// credentialsFile is the path to the file where the credentials are stored. This is | ||
// typically $HOME/.flyte/credentials.json but embedded as a private field for tests. | ||
credentialsFile string | ||
} | ||
|
||
func NewtokenCacheFilesystemProvider(serviceUser string) *tokenCacheFilesystemProvider { | ||
return &tokenCacheFilesystemProvider{ | ||
ServiceUser: serviceUser, | ||
credentialsFile: f.FilePathJoin(f.UserHomeDir(), ".flyte", "credentials.json"), | ||
} | ||
} | ||
|
||
type credentials map[string]*oauth2.Token | ||
|
||
func (c credentials) MarshalJSON() ([]byte, error) { | ||
m := make(map[string]string) | ||
for k, v := range c { | ||
b, err := json.Marshal(v) | ||
if err != nil { | ||
return nil, err | ||
} | ||
m[k] = b64.StdEncoding.EncodeToString(b) | ||
} | ||
return json.Marshal(m) | ||
} | ||
|
||
func (c credentials) UnmarshalJSON(b []byte) error { | ||
m := make(map[string]string) | ||
if err := json.Unmarshal(b, &m); err != nil { | ||
return err | ||
} | ||
for k, v := range m { | ||
s, err := b64.StdEncoding.DecodeString(v) | ||
if err != nil { | ||
return err | ||
} | ||
tk := &oauth2.Token{} | ||
if err = json.Unmarshal(s, tk); err != nil { | ||
return err | ||
} | ||
c[k] = tk | ||
} | ||
return nil | ||
} | ||
|
||
func (t tokenCacheFilesystemProvider) SaveToken(token *oauth2.Token) error { | ||
if token.AccessToken == "" { | ||
return fmt.Errorf("cannot save empty token with expiration %v", token.Expiry) | ||
} | ||
|
||
dir := filepath.Dir(t.credentialsFile) | ||
if err := os.MkdirAll(dir, 0700); err != nil { | ||
return fmt.Errorf("creating base directory (%s) for credentials: %s", dir, err.Error()) | ||
} | ||
|
||
creds, err := t.getExistingCredentials() | ||
if err != nil { | ||
return err | ||
} | ||
creds[t.ServiceUser] = token | ||
|
||
tmp, err := os.CreateTemp("", "flytectl") | ||
if err != nil { | ||
return fmt.Errorf("creating tmp file for credentials update: %s", err.Error()) | ||
} | ||
defer os.Remove(tmp.Name()) | ||
|
||
b, err := json.Marshal(creds) | ||
if err != nil { | ||
return fmt.Errorf("marshalling credentials: %s", err.Error()) | ||
} | ||
if _, err := tmp.Write(b); err != nil { | ||
return fmt.Errorf("writing updated credentials to tmp file: %s", err.Error()) | ||
} | ||
|
||
if err = os.Rename(tmp.Name(), t.credentialsFile); err != nil { | ||
return fmt.Errorf("updating credentials via tmp file rename: %s", err.Error()) | ||
} | ||
|
||
return nil | ||
} | ||
|
||
func (t tokenCacheFilesystemProvider) GetToken() (*oauth2.Token, error) { | ||
creds, err := t.getExistingCredentials() | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
if token, ok := creds[t.ServiceUser]; ok { | ||
return token, nil | ||
} | ||
|
||
return nil, errors.New("token does not exist") | ||
} | ||
|
||
func (t tokenCacheFilesystemProvider) getExistingCredentials() (credentials, error) { | ||
dir := filepath.Dir(t.credentialsFile) | ||
if err := os.MkdirAll(dir, 0700); err != nil { | ||
return nil, fmt.Errorf("creating base directory (%s) for credentials: %s", dir, err.Error()) | ||
} | ||
|
||
creds := credentials{} | ||
if _, err := os.Stat(t.credentialsFile); errors.Is(err, os.ErrNotExist) { | ||
return creds, nil | ||
} | ||
|
||
b, err := os.ReadFile(t.credentialsFile) | ||
if err != nil { | ||
return nil, fmt.Errorf("reading existing credentials: %s", err.Error()) | ||
} | ||
|
||
if err = json.Unmarshal(b, &creds); err != nil { | ||
return nil, fmt.Errorf("unmarshalling credentials: %s", err.Error()) | ||
} | ||
|
||
return creds, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
package pkce | ||
|
||
import ( | ||
"encoding/json" | ||
"os" | ||
"path/filepath" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/assert" | ||
"github.com/stretchr/testify/require" | ||
"golang.org/x/oauth2" | ||
) | ||
|
||
func TestSaveAndGetTokenFS(t *testing.T) { | ||
setup := func(t *testing.T) tokenCacheFilesystemProvider { | ||
t.Helper() | ||
// Everything inside the directory is automatically cleaned up by the test runner. | ||
dir := t.TempDir() | ||
tokenCacheProvider := tokenCacheFilesystemProvider{ | ||
ServiceUser: "testServiceUser", | ||
credentialsFile: filepath.Join(dir, "credentials.json"), | ||
} | ||
return tokenCacheProvider | ||
} | ||
|
||
t.Run("Valid Save/Get Token", func(t *testing.T) { | ||
tokenCacheProvider := setup(t) | ||
|
||
plan, err := os.ReadFile("testdata/token.json") | ||
require.NoError(t, err) | ||
|
||
var tokenData oauth2.Token | ||
err = json.Unmarshal(plan, &tokenData) | ||
require.NoError(t, err) | ||
|
||
err = tokenCacheProvider.SaveToken(&tokenData) | ||
require.NoError(t, err) | ||
|
||
var savedToken *oauth2.Token | ||
savedToken, err = tokenCacheProvider.GetToken() | ||
require.NoError(t, err) | ||
|
||
assert.NotNil(t, savedToken) | ||
assert.Equal(t, tokenData.AccessToken, savedToken.AccessToken) | ||
assert.Equal(t, tokenData.TokenType, savedToken.TokenType) | ||
assert.Equal(t, tokenData.Expiry, savedToken.Expiry) | ||
}) | ||
|
||
t.Run("Empty access token Save", func(t *testing.T) { | ||
tokenCacheProvider := setup(t) | ||
|
||
plan, err := os.ReadFile("testdata/empty_access_token.json") | ||
require.NoError(t, err) | ||
|
||
var tokenData oauth2.Token | ||
err = json.Unmarshal(plan, &tokenData) | ||
require.NoError(t, err) | ||
|
||
err = tokenCacheProvider.SaveToken(&tokenData) | ||
assert.Error(t, err) | ||
}) | ||
|
||
t.Run("Different service name", func(t *testing.T) { | ||
tokenCacheProvider := setup(t) | ||
|
||
plan, err := os.ReadFile("testdata/token.json") | ||
require.NoError(t, err) | ||
|
||
var tokenData oauth2.Token | ||
err = json.Unmarshal(plan, &tokenData) | ||
require.NoError(t, err) | ||
|
||
err = tokenCacheProvider.SaveToken(&tokenData) | ||
require.NoError(t, err) | ||
|
||
tokenCacheProvider2 := setup(t) | ||
|
||
var savedToken *oauth2.Token | ||
savedToken, err = tokenCacheProvider2.GetToken() | ||
assert.Error(t, err) | ||
assert.Nil(t, savedToken) | ||
|
||
err = tokenCacheProvider2.SaveToken(&tokenData) | ||
require.NoError(t, err) | ||
|
||
// new token exists | ||
savedToken, err = tokenCacheProvider2.GetToken() | ||
require.NoError(t, err) | ||
assert.NotNil(t, savedToken) | ||
|
||
// token for different service name still exists | ||
savedToken, err = tokenCacheProvider.GetToken() | ||
require.NoError(t, err) | ||
assert.NotNil(t, savedToken) | ||
}) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters