Skip to content

Commit

Permalink
auth: cleanups
Browse files Browse the repository at this point in the history
Signed-off-by: Laura Brehm <[email protected]>
  • Loading branch information
laurazard committed Jul 18, 2024
1 parent fe81d44 commit e29092a
Show file tree
Hide file tree
Showing 23 changed files with 447 additions and 155 deletions.
25 changes: 17 additions & 8 deletions cli/config/credentials/oauth_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@ func NewOAuthStore(backingStore Store, manager oauth.Manager) Store {
const minimumTokenLifetime = 50 * time.Minute

// Get retrieves the credentials from the backing store, refreshing the
// access token if the retrieved token is valid for less than 50 minutes.
// If there are no credentials in the backing store, the device code flow
// is initiated with the tenant in order to log the user in and get
// access token if the stored credentials are valid for less than minimumTokenLifetime.
// If the credentials being retrieved are not for the official registry, they are
// returned as is. If the credentials retrieved do not parse as a token, they are
// also returned as is.
func (o *oauthStore) Get(serverAddress string) (types.AuthConfig, error) {
if serverAddress != registry.IndexServer {
return o.backingStore.Get(serverAddress)
Expand All @@ -52,8 +53,9 @@ func (o *oauthStore) Get(serverAddress string) (types.AuthConfig, error) {
return auth, nil
}

// if the access token is valid for less than 50 minutes, refresh it
// if the access token is valid for less than minimumTokenLifetime, refresh it
if tokenRes.RefreshToken != "" && tokenRes.Claims.Expiry.Time().Before(time.Now().Add(minimumTokenLifetime)) {
// todo(laurazard): should use a context with a timeout here?
refreshRes, err := o.manager.RefreshToken(context.TODO(), tokenRes.RefreshToken)
if err != nil {
return types.AuthConfig{}, err
Expand All @@ -74,8 +76,9 @@ func (o *oauthStore) Get(serverAddress string) (types.AuthConfig, error) {
}, nil
}

// GetAll returns a map containing solely the auth config for the official
// registry, parsed from the backing store and refreshed if necessary.
// GetAll returns a map of all credentials in the backing store. If the backing
// store contains credentials for the official registry, these are refreshed/processed
// according to the same rules as Get.
func (o *oauthStore) GetAll() (map[string]types.AuthConfig, error) {
allAuths, err := o.backingStore.GetAll()
if err != nil {
Expand All @@ -98,8 +101,14 @@ func (o *oauthStore) GetAll() (map[string]types.AuthConfig, error) {
// tenant if running
func (o *oauthStore) Erase(serverAddress string) error {
if serverAddress == registry.IndexServer {
// todo(laurazard): should this log out from the tenant
_ = o.manager.Logout(context.TODO())
auth, err := o.backingStore.Get(registry.IndexServer)
if err != nil {
return err
}
if tokenRes, err := o.parseToken(auth.Password); err == nil {
// todo(laurazard): should use a context with a timeout here?
_ = o.manager.Logout(context.TODO(), tokenRes.RefreshToken)
}
}
return o.backingStore.Erase(serverAddress)
}
Expand Down
17 changes: 9 additions & 8 deletions cli/config/credentials/oauth_store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -286,13 +286,14 @@ func TestErase(t *testing.T) {
t.Run("official registry", func(t *testing.T) {
f := newStore(map[string]types.AuthConfig{
registry.IndexServer: {
Email: "[email protected]",
Email: "[email protected]",
Password: validNotExpiredToken + "..refresh-token",
},
})
var logoutCalled bool
var revokedToken string
manager := &testManager{
logout: func() error {
logoutCalled = true
logout: func(token string) error {
revokedToken = token
return nil
},
}
Expand All @@ -304,7 +305,7 @@ func TestErase(t *testing.T) {
assert.NilError(t, err)

assert.Check(t, is.Len(f.GetAuthConfigs(), 0))
assert.Check(t, logoutCalled)
assert.Equal(t, revokedToken, "refresh-token")
})

t.Run("different registry", func(t *testing.T) {
Expand Down Expand Up @@ -388,16 +389,16 @@ func TestStore(t *testing.T) {

type testManager struct {
loginDevice func() (oauth.TokenResult, error)
logout func() error
logout func(token string) error
refresh func(token string) (oauth.TokenResult, error)
}

func (m *testManager) LoginDevice(_ context.Context, _ io.Writer) (oauth.TokenResult, error) {
return m.loginDevice()
}

func (m *testManager) Logout(_ context.Context) error {
return m.logout()
func (m *testManager) Logout(_ context.Context, token string) error {
return m.logout(token)
}

func (m *testManager) RefreshToken(_ context.Context, token string) (oauth.TokenResult, error) {
Expand Down
64 changes: 48 additions & 16 deletions cli/internal/oauth/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,21 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"runtime"
"strings"
"time"

"github.com/docker/cli/cli/internal/oauth/util"
"github.com/docker/cli/cli/version"
)

type OAuthAPI interface {
GetDeviceCode(ctx context.Context, audience string) (State, error)
WaitForDeviceToken(ctx context.Context, state State) (TokenResponse, error)
Refresh(ctx context.Context, token string) (TokenResponse, error)
LogoutURL() string
RevokeToken(ctx context.Context, refreshToken string) error
}

// API represents API interactions with Auth0.
Expand All @@ -28,8 +30,6 @@ type API struct {
ClientID string
// Scopes are the scopes that are requested during the device auth flow.
Scopes []string
// Client is the client that is used for calls.
Client util.Client
}

// TokenResponse represents the response of the /oauth/token route.
Expand All @@ -46,7 +46,9 @@ type TokenResponse struct {

var ErrTimeout = errors.New("timed out waiting for device token")

// GetDeviceCode returns device code authorization information from Auth0.
// GetDeviceCode initiates the device-code auth flow with the tenant.
// The state returned contains the device code that the user must use to
// authenticate, as well as the URL to visit, etc.
func (a API) GetDeviceCode(ctx context.Context, audience string) (state State, err error) {
data := url.Values{
"client_id": {a.ClientID},
Expand All @@ -55,7 +57,7 @@ func (a API) GetDeviceCode(ctx context.Context, audience string) (state State, e
}

deviceCodeURL := a.BaseURL + "/oauth/device/code"
resp, err := a.Client.PostForm(ctx, deviceCodeURL, strings.NewReader(data.Encode()))
resp, err := postForm(ctx, deviceCodeURL, strings.NewReader(data.Encode()))
if err != nil {
return
}
Expand All @@ -77,8 +79,10 @@ func (a API) GetDeviceCode(ctx context.Context, audience string) (state State, e
return
}

// WaitForDeviceToken polls to get tokens based on the device code set up. This
// only works in a device auth flow.
// WaitForDeviceToken polls the tenant to get access/refresh tokens for the user.
// This should be called after GetDeviceCode, and will block until the user has
// authenticated or we have reached the time limit for authenticating (based on
// the response from GetDeviceCode).
func (a API) WaitForDeviceToken(ctx context.Context, state State) (TokenResponse, error) {
ticker := time.NewTicker(state.IntervalDuration())
timeout := time.After(time.Duration(state.ExpiresIn) * time.Second)
Expand Down Expand Up @@ -119,7 +123,7 @@ func (a API) getDeviceToken(ctx context.Context, state State) (res TokenResponse
}
oauthTokenURL := a.BaseURL + "/oauth/token"

resp, err := a.Client.PostForm(ctx, oauthTokenURL, strings.NewReader(data.Encode()))
resp, err := postForm(ctx, oauthTokenURL, strings.NewReader(data.Encode()))
if err != nil {
return res, fmt.Errorf("failed to get code: %w", err)
}
Expand All @@ -130,7 +134,7 @@ func (a API) getDeviceToken(ctx context.Context, state State) (res TokenResponse
return
}

// Refresh returns new tokens based on the refresh token.
// Refresh fetches new tokens using the refresh token.
func (a API) Refresh(ctx context.Context, token string) (res TokenResponse, err error) {
data := url.Values{
"grant_type": {"refresh_token"},
Expand All @@ -139,18 +143,46 @@ func (a API) Refresh(ctx context.Context, token string) (res TokenResponse, err
}

refreshURL := a.BaseURL + "/oauth/token"
//nolint:gosec // Ignore G107: Potential HTTP request made with variable url
resp, err := http.PostForm(refreshURL, data)
resp, err := postForm(ctx, refreshURL, strings.NewReader(data.Encode()))
if err != nil {
return
}
defer resp.Body.Close()

err = json.NewDecoder(resp.Body).Decode(&res)
_ = resp.Body.Close()

return
}

func (a API) LogoutURL() string {
return fmt.Sprintf("%s/v2/logout?client_id=%s", a.BaseURL, a.ClientID)
// RevokeToken revokes a refresh token with the tenant so that it can no longer
// be used to get new tokens.
func (a API) RevokeToken(ctx context.Context, refreshToken string) error {
data := url.Values{
"client_id": {a.ClientID},
"token": {refreshToken},
}

revokeURL := a.BaseURL + "/oauth/revoke"
resp, err := postForm(ctx, revokeURL, strings.NewReader(data.Encode()))
if err != nil {
return err
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return errors.New("failed to revoke token")
}
return nil
}

func postForm(ctx context.Context, reqURL string, data io.Reader) (*http.Response, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, data)
if err != nil {
return nil, err
}

req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
cliVersion := strings.ReplaceAll(version.Version, ".", "_")
req.Header.Set("User-Agent", fmt.Sprintf("docker-cli:%s:%s-%s", cliVersion, runtime.GOOS, runtime.GOARCH))

return http.DefaultClient.Do(req)
}
Loading

0 comments on commit e29092a

Please sign in to comment.