Skip to content

Commit

Permalink
auth: cleanups
Browse files Browse the repository at this point in the history
Signed-off-by: Laura Brehm <[email protected]>
  • Loading branch information
laurazard committed Jul 17, 2024
1 parent fe81d44 commit 0be6989
Show file tree
Hide file tree
Showing 11 changed files with 236 additions and 126 deletions.
25 changes: 17 additions & 8 deletions cli/config/credentials/oauth_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@ func NewOAuthStore(backingStore Store, manager oauth.Manager) Store {
const minimumTokenLifetime = 50 * time.Minute

// Get retrieves the credentials from the backing store, refreshing the
// access token if the retrieved token is valid for less than 50 minutes.
// If there are no credentials in the backing store, the device code flow
// is initiated with the tenant in order to log the user in and get
// access token if the stored credentials are valid for less than minimumTokenLifetime.
// If the credentials being retrieved are not for the official registry, they are
// returned as is. If the credentials retrieved do not parse as a token, they are
// also returned as is.
func (o *oauthStore) Get(serverAddress string) (types.AuthConfig, error) {
if serverAddress != registry.IndexServer {
return o.backingStore.Get(serverAddress)
Expand All @@ -52,8 +53,9 @@ func (o *oauthStore) Get(serverAddress string) (types.AuthConfig, error) {
return auth, nil
}

// if the access token is valid for less than 50 minutes, refresh it
// if the access token is valid for less than minimumTokenLifetime, refresh it
if tokenRes.RefreshToken != "" && tokenRes.Claims.Expiry.Time().Before(time.Now().Add(minimumTokenLifetime)) {
// todo(laurazard): should use a context with a timeout here?
refreshRes, err := o.manager.RefreshToken(context.TODO(), tokenRes.RefreshToken)
if err != nil {
return types.AuthConfig{}, err
Expand All @@ -74,8 +76,9 @@ func (o *oauthStore) Get(serverAddress string) (types.AuthConfig, error) {
}, nil
}

// GetAll returns a map containing solely the auth config for the official
// registry, parsed from the backing store and refreshed if necessary.
// GetAll returns a map of all credentials in the backing store. If the backing
// store contains credentials for the official registry, these are refreshed/processed
// according to the same rules as Get.
func (o *oauthStore) GetAll() (map[string]types.AuthConfig, error) {
allAuths, err := o.backingStore.GetAll()
if err != nil {
Expand All @@ -98,8 +101,14 @@ func (o *oauthStore) GetAll() (map[string]types.AuthConfig, error) {
// tenant if running
func (o *oauthStore) Erase(serverAddress string) error {
if serverAddress == registry.IndexServer {
// todo(laurazard): should this log out from the tenant
_ = o.manager.Logout(context.TODO())
auth, err := o.backingStore.Get(registry.IndexServer)
if err != nil {
return err
}
if tokenRes, err := o.parseToken(auth.Password); err == nil {
// todo(laurazard): should use a context with a timeout here?
_ = o.manager.Logout(context.TODO(), tokenRes.RefreshToken)
}
}
return o.backingStore.Erase(serverAddress)
}
Expand Down
17 changes: 9 additions & 8 deletions cli/config/credentials/oauth_store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -286,13 +286,14 @@ func TestErase(t *testing.T) {
t.Run("official registry", func(t *testing.T) {
f := newStore(map[string]types.AuthConfig{
registry.IndexServer: {
Email: "[email protected]",
Email: "[email protected]",
Password: validNotExpiredToken + "..refresh-token",
},
})
var logoutCalled bool
var revokedToken string
manager := &testManager{
logout: func() error {
logoutCalled = true
logout: func(token string) error {
revokedToken = token
return nil
},
}
Expand All @@ -304,7 +305,7 @@ func TestErase(t *testing.T) {
assert.NilError(t, err)

assert.Check(t, is.Len(f.GetAuthConfigs(), 0))
assert.Check(t, logoutCalled)
assert.Equal(t, revokedToken, "refresh-token")
})

t.Run("different registry", func(t *testing.T) {
Expand Down Expand Up @@ -388,16 +389,16 @@ func TestStore(t *testing.T) {

type testManager struct {
loginDevice func() (oauth.TokenResult, error)
logout func() error
logout func(token string) error
refresh func(token string) (oauth.TokenResult, error)
}

func (m *testManager) LoginDevice(_ context.Context, _ io.Writer) (oauth.TokenResult, error) {
return m.loginDevice()
}

func (m *testManager) Logout(_ context.Context) error {
return m.logout()
func (m *testManager) Logout(_ context.Context, token string) error {
return m.logout(token)
}

func (m *testManager) RefreshToken(_ context.Context, token string) (oauth.TokenResult, error) {
Expand Down
30 changes: 22 additions & 8 deletions cli/internal/oauth/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ type OAuthAPI interface {
GetDeviceCode(ctx context.Context, audience string) (State, error)
WaitForDeviceToken(ctx context.Context, state State) (TokenResponse, error)
Refresh(ctx context.Context, token string) (TokenResponse, error)
LogoutURL() string
Revoke(ctx context.Context, refreshToken string) error
}

// API represents API interactions with Auth0.
Expand All @@ -28,8 +28,6 @@ type API struct {
ClientID string
// Scopes are the scopes that are requested during the device auth flow.
Scopes []string
// Client is the client that is used for calls.
Client util.Client
}

// TokenResponse represents the response of the /oauth/token route.
Expand All @@ -55,7 +53,7 @@ func (a API) GetDeviceCode(ctx context.Context, audience string) (state State, e
}

deviceCodeURL := a.BaseURL + "/oauth/device/code"
resp, err := a.Client.PostForm(ctx, deviceCodeURL, strings.NewReader(data.Encode()))
resp, err := util.PostForm(ctx, deviceCodeURL, strings.NewReader(data.Encode()))
if err != nil {
return
}
Expand Down Expand Up @@ -119,7 +117,7 @@ func (a API) getDeviceToken(ctx context.Context, state State) (res TokenResponse
}
oauthTokenURL := a.BaseURL + "/oauth/token"

resp, err := a.Client.PostForm(ctx, oauthTokenURL, strings.NewReader(data.Encode()))
resp, err := util.PostForm(ctx, oauthTokenURL, strings.NewReader(data.Encode()))
if err != nil {
return res, fmt.Errorf("failed to get code: %w", err)
}
Expand All @@ -140,7 +138,7 @@ func (a API) Refresh(ctx context.Context, token string) (res TokenResponse, err

refreshURL := a.BaseURL + "/oauth/token"
//nolint:gosec // Ignore G107: Potential HTTP request made with variable url
resp, err := http.PostForm(refreshURL, data)
resp, err := util.PostForm(ctx, refreshURL, strings.NewReader(data.Encode()))
if err != nil {
return
}
Expand All @@ -151,6 +149,22 @@ func (a API) Refresh(ctx context.Context, token string) (res TokenResponse, err
return
}

func (a API) LogoutURL() string {
return fmt.Sprintf("%s/v2/logout?client_id=%s", a.BaseURL, a.ClientID)
func (a API) Revoke(ctx context.Context, refreshToken string) error {
data := url.Values{
"client_id": {a.ClientID},
"token": {refreshToken},
}

revokeURL := a.BaseURL + "/oauth/revoke"
//nolint:gosec // Ignore G107: Potential HTTP request made with variable url
resp, err := util.PostForm(ctx, revokeURL, strings.NewReader(data.Encode()))
if err != nil {
return err
}

if resp.StatusCode != http.StatusOK {
return errors.New("failed to revoke token")
}

return nil
}
122 changes: 114 additions & 8 deletions cli/internal/oauth/api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"testing"
"time"

"github.com/docker/cli/cli/internal/oauth/util"
"gotest.tools/v3/assert"
)

Expand Down Expand Up @@ -38,7 +37,6 @@ func TestGetDeviceCode(t *testing.T) {
BaseURL: ts.URL,
ClientID: "aClientID",
Scopes: []string{"bork", "meow"},
Client: util.Client{},
}

state, err := api.GetDeviceCode(context.Background(), "anAudience")
Expand Down Expand Up @@ -66,7 +64,6 @@ func TestGetDeviceCode(t *testing.T) {
BaseURL: ts.URL,
ClientID: "aClientID",
Scopes: []string{"bork", "meow"},
Client: util.Client{},
}

_, err := api.GetDeviceCode(context.Background(), "bad_audience")
Expand All @@ -83,7 +80,6 @@ func TestGetDeviceCode(t *testing.T) {
BaseURL: ts.URL,
ClientID: "aClientID",
Scopes: []string{"bork", "meow"},
Client: util.Client{},
}

_, err := api.GetDeviceCode(context.Background(), "anAudience")
Expand All @@ -101,7 +97,6 @@ func TestGetDeviceCode(t *testing.T) {
BaseURL: ts.URL,
ClientID: "aClientID",
Scopes: []string{"bork", "meow"},
Client: util.Client{},
}

ctx, cancel := context.WithCancel(context.Background())
Expand Down Expand Up @@ -155,7 +150,6 @@ func TestWaitForDeviceToken(t *testing.T) {
BaseURL: ts.URL,
ClientID: "aClientID",
Scopes: []string{"bork", "meow"},
Client: util.Client{},
}
state := State{
DeviceCode: "aDeviceCode",
Expand Down Expand Up @@ -189,7 +183,6 @@ func TestWaitForDeviceToken(t *testing.T) {
BaseURL: ts.URL,
ClientID: "aClientID",
Scopes: []string{"bork", "meow"},
Client: util.Client{},
}
state := State{
DeviceCode: "aDeviceCode",
Expand Down Expand Up @@ -217,7 +210,6 @@ func TestWaitForDeviceToken(t *testing.T) {
BaseURL: ts.URL,
ClientID: "aClientID",
Scopes: []string{"bork", "meow"},
Client: util.Client{},
}
state := State{
DeviceCode: "aDeviceCode",
Expand All @@ -236,3 +228,117 @@ func TestWaitForDeviceToken(t *testing.T) {
assert.ErrorContains(t, err, "context canceled")
})
}

func TestRefresh(t *testing.T) {
t.Run("success", func(t *testing.T) {
expectedToken := TokenResponse{
AccessToken: "a-real-token",
IDToken: "",
RefreshToken: "the-refresh-token",
Scope: "",
ExpiresIn: 3600,
TokenType: "",
}
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
assert.Equal(t, "/oauth/token", r.URL.Path)
assert.Equal(t, r.FormValue("client_id"), "aClientID")
assert.Equal(t, r.FormValue("refresh_token"), "v1.a-refresh-token")
assert.Equal(t, r.FormValue("grant_type"), "refresh_token")

jsonState, err := json.Marshal(expectedToken)
assert.NilError(t, err)
w.Write(jsonState)
}))
defer ts.Close()
api := API{
BaseURL: ts.URL,
ClientID: "aClientID",
Scopes: []string{"bork", "meow"},
}

token, err := api.Refresh(context.Background(), "v1.a-refresh-token")
assert.NilError(t, err)

assert.DeepEqual(t, token, expectedToken)
})

t.Run("canceled context", func(t *testing.T) {
expectedToken := TokenResponse{
AccessToken: "a-real-token",
IDToken: "",
RefreshToken: "the-refresh-token",
Scope: "",
ExpiresIn: 3600,
TokenType: "",
}
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
assert.Equal(t, "/oauth/token", r.URL.Path)
assert.Equal(t, r.FormValue("client_id"), "aClientID")
assert.Equal(t, r.FormValue("refresh_token"), "v1.a-refresh-token")
assert.Equal(t, r.FormValue("grant_type"), "refresh_token")

jsonState, err := json.Marshal(expectedToken)
assert.NilError(t, err)
w.Write(jsonState)
}))
defer ts.Close()
api := API{
BaseURL: ts.URL,
ClientID: "aClientID",
Scopes: []string{"bork", "meow"},
}
ctx, cancel := context.WithCancel(context.Background())
cancel()

_, err := api.Refresh(ctx, "v1.a-refresh-token")

assert.ErrorContains(t, err, "context canceled")
})
}

func TestRevoke(t *testing.T) {
t.Run("success", func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
assert.Equal(t, "/oauth/revoke", r.URL.Path)
assert.Equal(t, r.FormValue("client_id"), "aClientID")
assert.Equal(t, r.FormValue("token"), "v1.a-refresh-token")

w.WriteHeader(http.StatusOK)
}))
defer ts.Close()
api := API{
BaseURL: ts.URL,
ClientID: "aClientID",
Scopes: []string{"bork", "meow"},
}

err := api.Revoke(context.Background(), "v1.a-refresh-token")
assert.NilError(t, err)
})

t.Run("canceled context", func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
assert.Equal(t, "/oauth/revoke", r.URL.Path)
assert.Equal(t, r.FormValue("client_id"), "aClientID")
assert.Equal(t, r.FormValue("token"), "v1.a-refresh-token")

w.WriteHeader(http.StatusOK)
}))
defer ts.Close()
api := API{
BaseURL: ts.URL,
ClientID: "aClientID",
Scopes: []string{"bork", "meow"},
}
ctx, cancel := context.WithCancel(context.Background())
cancel()

err := api.Revoke(ctx, "v1.a-refresh-token")

assert.ErrorContains(t, err, "context canceled")
})
}
Loading

0 comments on commit 0be6989

Please sign in to comment.