diff --git a/cli/config/credentials/oauth_store.go b/cli/config/credentials/oauth_store.go index 399220c63e50..dc14ad2a8256 100644 --- a/cli/config/credentials/oauth_store.go +++ b/cli/config/credentials/oauth_store.go @@ -29,9 +29,10 @@ func NewOAuthStore(backingStore Store, manager oauth.Manager) Store { const minimumTokenLifetime = 50 * time.Minute // Get retrieves the credentials from the backing store, refreshing the -// access token if the retrieved token is valid for less than 50 minutes. -// If there are no credentials in the backing store, the device code flow -// is initiated with the tenant in order to log the user in and get +// access token if the stored credentials are valid for less than minimumTokenLifetime. +// If the credentials being retrieved are not for the official registry, they are +// returned as is. If the credentials retrieved do not parse as a token, they are +// also returned as is. func (o *oauthStore) Get(serverAddress string) (types.AuthConfig, error) { if serverAddress != registry.IndexServer { return o.backingStore.Get(serverAddress) @@ -52,8 +53,9 @@ func (o *oauthStore) Get(serverAddress string) (types.AuthConfig, error) { return auth, nil } - // if the access token is valid for less than 50 minutes, refresh it + // if the access token is valid for less than minimumTokenLifetime, refresh it if tokenRes.RefreshToken != "" && tokenRes.Claims.Expiry.Time().Before(time.Now().Add(minimumTokenLifetime)) { + // todo(laurazard): should use a context with a timeout here? refreshRes, err := o.manager.RefreshToken(context.TODO(), tokenRes.RefreshToken) if err != nil { return types.AuthConfig{}, err @@ -74,8 +76,9 @@ func (o *oauthStore) Get(serverAddress string) (types.AuthConfig, error) { }, nil } -// GetAll returns a map containing solely the auth config for the official -// registry, parsed from the backing store and refreshed if necessary. +// GetAll returns a map of all credentials in the backing store. If the backing +// store contains credentials for the official registry, these are refreshed/processed +// according to the same rules as Get. func (o *oauthStore) GetAll() (map[string]types.AuthConfig, error) { allAuths, err := o.backingStore.GetAll() if err != nil { @@ -98,8 +101,14 @@ func (o *oauthStore) GetAll() (map[string]types.AuthConfig, error) { // tenant if running func (o *oauthStore) Erase(serverAddress string) error { if serverAddress == registry.IndexServer { - // todo(laurazard): should this log out from the tenant - _ = o.manager.Logout(context.TODO()) + auth, err := o.backingStore.Get(registry.IndexServer) + if err != nil { + return err + } + if tokenRes, err := o.parseToken(auth.Password); err == nil { + // todo(laurazard): should use a context with a timeout here? + _ = o.manager.Logout(context.TODO(), tokenRes.RefreshToken) + } } return o.backingStore.Erase(serverAddress) } diff --git a/cli/config/credentials/oauth_store_test.go b/cli/config/credentials/oauth_store_test.go index 19287f8edc9b..cc24fac5b30a 100644 --- a/cli/config/credentials/oauth_store_test.go +++ b/cli/config/credentials/oauth_store_test.go @@ -286,13 +286,14 @@ func TestErase(t *testing.T) { t.Run("official registry", func(t *testing.T) { f := newStore(map[string]types.AuthConfig{ registry.IndexServer: { - Email: "foo@example.com", + Email: "foo@example.com", + Password: validNotExpiredToken + "..refresh-token", }, }) - var logoutCalled bool + var revokedToken string manager := &testManager{ - logout: func() error { - logoutCalled = true + logout: func(token string) error { + revokedToken = token return nil }, } @@ -304,7 +305,7 @@ func TestErase(t *testing.T) { assert.NilError(t, err) assert.Check(t, is.Len(f.GetAuthConfigs(), 0)) - assert.Check(t, logoutCalled) + assert.Equal(t, revokedToken, "refresh-token") }) t.Run("different registry", func(t *testing.T) { @@ -388,7 +389,7 @@ func TestStore(t *testing.T) { type testManager struct { loginDevice func() (oauth.TokenResult, error) - logout func() error + logout func(token string) error refresh func(token string) (oauth.TokenResult, error) } @@ -396,8 +397,8 @@ func (m *testManager) LoginDevice(_ context.Context, _ io.Writer) (oauth.TokenRe return m.loginDevice() } -func (m *testManager) Logout(_ context.Context) error { - return m.logout() +func (m *testManager) Logout(_ context.Context, token string) error { + return m.logout(token) } func (m *testManager) RefreshToken(_ context.Context, token string) (oauth.TokenResult, error) { diff --git a/cli/internal/oauth/api/api.go b/cli/internal/oauth/api/api.go index 539b68c21c3c..53c12614331f 100644 --- a/cli/internal/oauth/api/api.go +++ b/cli/internal/oauth/api/api.go @@ -17,7 +17,7 @@ type OAuthAPI interface { GetDeviceCode(ctx context.Context, audience string) (State, error) WaitForDeviceToken(ctx context.Context, state State) (TokenResponse, error) Refresh(ctx context.Context, token string) (TokenResponse, error) - LogoutURL() string + Revoke(ctx context.Context, refreshToken string) error } // API represents API interactions with Auth0. @@ -28,8 +28,6 @@ type API struct { ClientID string // Scopes are the scopes that are requested during the device auth flow. Scopes []string - // Client is the client that is used for calls. - Client util.Client } // TokenResponse represents the response of the /oauth/token route. @@ -55,7 +53,7 @@ func (a API) GetDeviceCode(ctx context.Context, audience string) (state State, e } deviceCodeURL := a.BaseURL + "/oauth/device/code" - resp, err := a.Client.PostForm(ctx, deviceCodeURL, strings.NewReader(data.Encode())) + resp, err := util.PostForm(ctx, deviceCodeURL, strings.NewReader(data.Encode())) if err != nil { return } @@ -119,7 +117,7 @@ func (a API) getDeviceToken(ctx context.Context, state State) (res TokenResponse } oauthTokenURL := a.BaseURL + "/oauth/token" - resp, err := a.Client.PostForm(ctx, oauthTokenURL, strings.NewReader(data.Encode())) + resp, err := util.PostForm(ctx, oauthTokenURL, strings.NewReader(data.Encode())) if err != nil { return res, fmt.Errorf("failed to get code: %w", err) } @@ -140,7 +138,7 @@ func (a API) Refresh(ctx context.Context, token string) (res TokenResponse, err refreshURL := a.BaseURL + "/oauth/token" //nolint:gosec // Ignore G107: Potential HTTP request made with variable url - resp, err := http.PostForm(refreshURL, data) + resp, err := util.PostForm(ctx, refreshURL, strings.NewReader(data.Encode())) if err != nil { return } @@ -151,6 +149,22 @@ func (a API) Refresh(ctx context.Context, token string) (res TokenResponse, err return } -func (a API) LogoutURL() string { - return fmt.Sprintf("%s/v2/logout?client_id=%s", a.BaseURL, a.ClientID) +func (a API) Revoke(ctx context.Context, refreshToken string) error { + data := url.Values{ + "client_id": {a.ClientID}, + "token": {refreshToken}, + } + + revokeURL := a.BaseURL + "/oauth/revoke" + //nolint:gosec // Ignore G107: Potential HTTP request made with variable url + resp, err := util.PostForm(ctx, revokeURL, strings.NewReader(data.Encode())) + if err != nil { + return err + } + + if resp.StatusCode != http.StatusOK { + return errors.New("failed to revoke token") + } + + return nil } diff --git a/cli/internal/oauth/api/api_test.go b/cli/internal/oauth/api/api_test.go index 29a62e8da30e..fc463ac95efe 100644 --- a/cli/internal/oauth/api/api_test.go +++ b/cli/internal/oauth/api/api_test.go @@ -8,7 +8,6 @@ import ( "testing" "time" - "github.com/docker/cli/cli/internal/oauth/util" "gotest.tools/v3/assert" ) @@ -38,7 +37,6 @@ func TestGetDeviceCode(t *testing.T) { BaseURL: ts.URL, ClientID: "aClientID", Scopes: []string{"bork", "meow"}, - Client: util.Client{}, } state, err := api.GetDeviceCode(context.Background(), "anAudience") @@ -66,7 +64,6 @@ func TestGetDeviceCode(t *testing.T) { BaseURL: ts.URL, ClientID: "aClientID", Scopes: []string{"bork", "meow"}, - Client: util.Client{}, } _, err := api.GetDeviceCode(context.Background(), "bad_audience") @@ -83,7 +80,6 @@ func TestGetDeviceCode(t *testing.T) { BaseURL: ts.URL, ClientID: "aClientID", Scopes: []string{"bork", "meow"}, - Client: util.Client{}, } _, err := api.GetDeviceCode(context.Background(), "anAudience") @@ -101,7 +97,6 @@ func TestGetDeviceCode(t *testing.T) { BaseURL: ts.URL, ClientID: "aClientID", Scopes: []string{"bork", "meow"}, - Client: util.Client{}, } ctx, cancel := context.WithCancel(context.Background()) @@ -155,7 +150,6 @@ func TestWaitForDeviceToken(t *testing.T) { BaseURL: ts.URL, ClientID: "aClientID", Scopes: []string{"bork", "meow"}, - Client: util.Client{}, } state := State{ DeviceCode: "aDeviceCode", @@ -189,7 +183,6 @@ func TestWaitForDeviceToken(t *testing.T) { BaseURL: ts.URL, ClientID: "aClientID", Scopes: []string{"bork", "meow"}, - Client: util.Client{}, } state := State{ DeviceCode: "aDeviceCode", @@ -217,7 +210,6 @@ func TestWaitForDeviceToken(t *testing.T) { BaseURL: ts.URL, ClientID: "aClientID", Scopes: []string{"bork", "meow"}, - Client: util.Client{}, } state := State{ DeviceCode: "aDeviceCode", @@ -236,3 +228,117 @@ func TestWaitForDeviceToken(t *testing.T) { assert.ErrorContains(t, err, "context canceled") }) } + +func TestRefresh(t *testing.T) { + t.Run("success", func(t *testing.T) { + expectedToken := TokenResponse{ + AccessToken: "a-real-token", + IDToken: "", + RefreshToken: "the-refresh-token", + Scope: "", + ExpiresIn: 3600, + TokenType: "", + } + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "POST", r.Method) + assert.Equal(t, "/oauth/token", r.URL.Path) + assert.Equal(t, r.FormValue("client_id"), "aClientID") + assert.Equal(t, r.FormValue("refresh_token"), "v1.a-refresh-token") + assert.Equal(t, r.FormValue("grant_type"), "refresh_token") + + jsonState, err := json.Marshal(expectedToken) + assert.NilError(t, err) + w.Write(jsonState) + })) + defer ts.Close() + api := API{ + BaseURL: ts.URL, + ClientID: "aClientID", + Scopes: []string{"bork", "meow"}, + } + + token, err := api.Refresh(context.Background(), "v1.a-refresh-token") + assert.NilError(t, err) + + assert.DeepEqual(t, token, expectedToken) + }) + + t.Run("canceled context", func(t *testing.T) { + expectedToken := TokenResponse{ + AccessToken: "a-real-token", + IDToken: "", + RefreshToken: "the-refresh-token", + Scope: "", + ExpiresIn: 3600, + TokenType: "", + } + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "POST", r.Method) + assert.Equal(t, "/oauth/token", r.URL.Path) + assert.Equal(t, r.FormValue("client_id"), "aClientID") + assert.Equal(t, r.FormValue("refresh_token"), "v1.a-refresh-token") + assert.Equal(t, r.FormValue("grant_type"), "refresh_token") + + jsonState, err := json.Marshal(expectedToken) + assert.NilError(t, err) + w.Write(jsonState) + })) + defer ts.Close() + api := API{ + BaseURL: ts.URL, + ClientID: "aClientID", + Scopes: []string{"bork", "meow"}, + } + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, err := api.Refresh(ctx, "v1.a-refresh-token") + + assert.ErrorContains(t, err, "context canceled") + }) +} + +func TestRevoke(t *testing.T) { + t.Run("success", func(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "POST", r.Method) + assert.Equal(t, "/oauth/revoke", r.URL.Path) + assert.Equal(t, r.FormValue("client_id"), "aClientID") + assert.Equal(t, r.FormValue("token"), "v1.a-refresh-token") + + w.WriteHeader(http.StatusOK) + })) + defer ts.Close() + api := API{ + BaseURL: ts.URL, + ClientID: "aClientID", + Scopes: []string{"bork", "meow"}, + } + + err := api.Revoke(context.Background(), "v1.a-refresh-token") + assert.NilError(t, err) + }) + + t.Run("canceled context", func(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "POST", r.Method) + assert.Equal(t, "/oauth/revoke", r.URL.Path) + assert.Equal(t, r.FormValue("client_id"), "aClientID") + assert.Equal(t, r.FormValue("token"), "v1.a-refresh-token") + + w.WriteHeader(http.StatusOK) + })) + defer ts.Close() + api := API{ + BaseURL: ts.URL, + ClientID: "aClientID", + Scopes: []string{"bork", "meow"}, + } + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := api.Revoke(ctx, "v1.a-refresh-token") + + assert.ErrorContains(t, err, "context canceled") + }) +} diff --git a/cli/internal/oauth/manager/manager.go b/cli/internal/oauth/manager/manager.go index c5477d867114..34fabcbcbcdc 100644 --- a/cli/internal/oauth/manager/manager.go +++ b/cli/internal/oauth/manager/manager.go @@ -22,7 +22,7 @@ type OAuthManager struct { openBrowser func(string) error } -// OAuthManagerOptions is the options used for New to create a new auth manager. +// OAuthManagerOptions are the options used for New to create a new auth manager. type OAuthManagerOptions struct { Audience string ClientID string @@ -49,18 +49,16 @@ func New(options OAuthManagerOptions) *OAuthManager { BaseURL: "https://" + options.Tenant, ClientID: options.ClientID, Scopes: scopes, - Client: util.Client{ - UserAgent: options.DeviceName, - }, }, tenant: options.Tenant, openBrowser: openBrowser, } } -// LoginDevice launches the device authentication flow with the tenant, printing instructions -// to the provided writer and attempting to open the browser for the user to authenticate. -// Once complete, the retrieved tokens are stored and returned. +// LoginDevice launches the device authentication flow with the tenant, printing +// instructions to the provided writer and attempting to open the browser for +// the user to authenticate. +// Once complete, the retrieved tokens are returned. func (m *OAuthManager) LoginDevice(ctx context.Context, w io.Writer) (res oauth.TokenResult, err error) { state, err := m.api.GetDeviceCode(ctx, m.audience) if err != nil { @@ -75,7 +73,7 @@ func (m *OAuthManager) LoginDevice(ctx context.Context, w io.Writer) (res oauth. _, _ = fmt.Fprintln(w, "To sign in with credentials on the command line, use 'docker login -u '") _, _ = fmt.Fprintf(w, "\nYour one-time device confirmation code is: %s\n", state.UserCode) _, _ = fmt.Fprint(w, "\nPress ENTER to open the browser.\n") - _, _ = fmt.Fprintf(w, "Or open the URL manually: %s.\n", strings.Split(state.VerificationURI, "?")[0]) + _, _ = fmt.Fprintf(w, "Or open the URL manually: %s\n", strings.Split(state.VerificationURI, "?")[0]) tokenResChan := make(chan api.TokenResponse) waitForTokenErrChan := make(chan error) @@ -117,25 +115,14 @@ func (m *OAuthManager) LoginDevice(ctx context.Context, w io.Writer) (res oauth. return res, nil } -// Logout logs out of the session for the client and removes tokens from the storage provider. -func (m *OAuthManager) Logout(ctx context.Context) error { - return errors.Join( - m.openBrowser(m.api.LogoutURL()), - ) +// Logout revokes the provided refresh token with the oauth tenant. However, it +// does not end the user's session with the tenant. +func (m *OAuthManager) Logout(ctx context.Context, refreshToken string) error { + return m.api.Revoke(ctx, refreshToken) } -var ( - // ErrNoCreds is returned by RefreshToken when the store does not contain credentials - // for the official registry. - ErrNoCreds = errors.New("no credentials found") - - // ErrUnexpiredToken is returned by RefreshToken when the token is not expired. - ErrUnexpiredToken = errors.New("token is not expired") -) - -// RefreshToken fetches credentials from the store, refreshes them, stores the new tokens -// and returns them. -// If there are no credentials in the store, ErrNoCreds is returned. +// RefreshToken uses the provided token to refresh access with the oauth +// tenant, returning new access and refresh token. func (m OAuthManager) RefreshToken(ctx context.Context, refreshToken string) (res oauth.TokenResult, err error) { refreshRes, err := m.api.Refresh(ctx, refreshToken) if err != nil { diff --git a/cli/internal/oauth/manager/manager_test.go b/cli/internal/oauth/manager/manager_test.go index 9945605f04cb..4bd0adb0a279 100644 --- a/cli/internal/oauth/manager/manager_test.go +++ b/cli/internal/oauth/manager/manager_test.go @@ -176,22 +176,21 @@ func TestLoginDevice(t *testing.T) { } func TestLogout(t *testing.T) { + var receivedToken string a := &testAPI{ - logoutURL: "test-logout-url", + revokeToken: func(token string) error { + receivedToken = token + return nil + }, } - var browserOpenURL string manager := OAuthManager{ api: a, - openBrowser: func(url string) error { - browserOpenURL = url - return nil - }, } - err := manager.Logout(context.Background()) + err := manager.Logout(context.Background(), "a-refresh-token") assert.NilError(t, err) - assert.Equal(t, browserOpenURL, "test-logout-url") + assert.Equal(t, receivedToken, "a-refresh-token") } func TestRefreshToken(t *testing.T) { @@ -240,10 +239,10 @@ func TestRefreshToken(t *testing.T) { var _ api.OAuthAPI = &testAPI{} type testAPI struct { - logoutURL string getDeviceToken func(audience string) (api.State, error) waitForDeviceToken func(state api.State) (api.TokenResponse, error) refresh func(token string) (api.TokenResponse, error) + revokeToken func(token string) error } func (t *testAPI) GetDeviceCode(_ context.Context, audience string) (api.State, error) { @@ -267,6 +266,9 @@ func (t *testAPI) Refresh(_ context.Context, token string) (api.TokenResponse, e return api.TokenResponse{}, nil } -func (t *testAPI) LogoutURL() string { - return t.logoutURL +func (t *testAPI) Revoke(_ context.Context, token string) error { + if t.revokeToken != nil { + return t.revokeToken(token) + } + return nil } diff --git a/cli/internal/oauth/util/browser.go b/cli/internal/oauth/util/browser.go new file mode 100644 index 000000000000..db71404d4642 --- /dev/null +++ b/cli/internal/oauth/util/browser.go @@ -0,0 +1,38 @@ +package util + +import ( + "errors" + "os/exec" + "runtime" + "strings" +) + +// https://github.com/docker/pinata/blob/675a1c7d8ae965bb44a0679ff9ff7a108b82b9e0/common/cmd/com.docker.backend/internal/auth/browser/browser.go#L27 +// OpenBrowser opens the specified URL in a browser based on OS. +func OpenBrowser(url string) error { + switch runtime.GOOS { + case "linux": + return openBrowserLinux(url) + case "windows": + return exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start() + case "darwin": + return exec.Command("open", url).Start() + default: + return errors.New("unsupported platform") + } +} + +func openBrowserLinux(url string) error { + providers := []string{"wslview", "xdg-open", "x-www-browser", "www-browser"} + + // There are multiple possible providers to open a browser on linux + // One of them is xdg-open, another is x-www-browser, then there's www-browser, etc. + // Look for one that exists and run it + for _, provider := range providers { + if _, err := exec.LookPath(provider); err == nil { + return exec.Command(provider, url).Start() + } + } + + return &exec.Error{Name: strings.Join(providers, ","), Err: exec.ErrNotFound} +} diff --git a/cli/internal/oauth/util/client.go b/cli/internal/oauth/util/client.go index d11fa7fd1f52..d91b9c4b3b24 100644 --- a/cli/internal/oauth/util/client.go +++ b/cli/internal/oauth/util/client.go @@ -2,28 +2,16 @@ package util import ( "context" + "fmt" "io" "net/http" -) - -// Client is a client and actions for interacting with the tenant auth API. -type Client struct { - UserAgent string -} - -// setHeaders sets common headers for requests. -func (c Client) setHeaders(req *http.Request, isForm bool) { - if isForm { - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - } + "runtime" + "strings" - if c.UserAgent != "" { - req.Header.Set("User-Agent", c.UserAgent) - } -} + "github.com/docker/cli/cli/version" +) -// PostForm does a POST request with form data. -func (c Client) PostForm(ctx context.Context, url string, data io.Reader) (*http.Response, error) { +func PostForm(ctx context.Context, url string, data io.Reader) (*http.Response, error) { client := http.Client{} req, err := http.NewRequest(http.MethodPost, url, data) @@ -32,21 +20,9 @@ func (c Client) PostForm(ctx context.Context, url string, data io.Reader) (*http } req = req.WithContext(ctx) - c.setHeaders(req, true) - - return client.Do(req) -} - -// Post does a POST with the specified data. -func (c Client) Post(url string, data io.Reader) (*http.Response, error) { - client := http.Client{} - - req, err := http.NewRequest(http.MethodPost, url, data) - if err != nil { - return nil, err - } - - c.setHeaders(req, false) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + cliVersion := strings.ReplaceAll(version.Version, ".", "_") + req.Header.Set("User-Agent", fmt.Sprintf("docker-cli:%s:%s-%s", cliVersion, runtime.GOOS, runtime.GOARCH)) return client.Do(req) } diff --git a/cli/internal/oauth/util/util.go b/cli/internal/oauth/util/util.go deleted file mode 100644 index ac18953af008..000000000000 --- a/cli/internal/oauth/util/util.go +++ /dev/null @@ -1,23 +0,0 @@ -package util - -import ( - "errors" - "os/exec" - "runtime" -) - -// OpenBrowser opens the specified URL in a browser based on OS. -func OpenBrowser(url string) (err error) { - switch runtime.GOOS { - case "linux": - err = exec.Command("xdg-open", url).Start() - case "windows": - err = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start() - case "darwin": - err = exec.Command("open", url).Start() - default: - err = errors.New("unsupported platform") - } - - return -} diff --git a/cli/oauth/manager.go b/cli/oauth/manager.go index e0e3be283aba..56389e918a4a 100644 --- a/cli/oauth/manager.go +++ b/cli/oauth/manager.go @@ -16,6 +16,6 @@ type TokenResult struct { type Manager interface { LoginDevice(ctx context.Context, out io.Writer) (TokenResult, error) - Logout(ctx context.Context) error + Logout(ctx context.Context, refreshToken string) error RefreshToken(ctx context.Context, refreshToken string) (TokenResult, error) } diff --git a/internal/test/cli.go b/internal/test/cli.go index ebd3ffbb16df..080e4037952b 100644 --- a/internal/test/cli.go +++ b/internal/test/cli.go @@ -224,7 +224,7 @@ func (f *fakeOauthManager) LoginDevice(ctx context.Context, w io.Writer) (res oa return res, nil } -func (f *fakeOauthManager) Logout(ctx context.Context) error { +func (f *fakeOauthManager) Logout(ctx context.Context, refreshToken string) error { return nil }