diff --git a/cli/command/cli.go b/cli/command/cli.go index 06478c9c776e..b93c80b3a25c 100644 --- a/cli/command/cli.go +++ b/cli/command/cli.go @@ -21,7 +21,9 @@ import ( "github.com/docker/cli/cli/context/store" "github.com/docker/cli/cli/debug" cliflags "github.com/docker/cli/cli/flags" + "github.com/docker/cli/cli/internal/oauth/manager" manifeststore "github.com/docker/cli/cli/manifest/store" + "github.com/docker/cli/cli/oauth" registryclient "github.com/docker/cli/cli/registry/client" "github.com/docker/cli/cli/streams" "github.com/docker/cli/cli/trust" @@ -66,6 +68,7 @@ type Cli interface { CurrentContext() string DockerEndpoint() docker.Endpoint TelemetryClient + OAuthManager() oauth.Manager } // DockerCli is an instance the docker command line client. @@ -86,6 +89,7 @@ type DockerCli struct { dockerEndpoint docker.Endpoint contextStoreConfig store.Config initTimeout time.Duration + oauthManager oauth.Manager res telemetryResource // baseCtx is the base context used for internal operations. In the future @@ -96,6 +100,10 @@ type DockerCli struct { enableGlobalMeter, enableGlobalTracer bool } +func (cli *DockerCli) OAuthManager() oauth.Manager { + return cli.oauthManager +} + // DefaultVersion returns api.defaultVersion. func (cli *DockerCli) DefaultVersion() string { return api.DefaultVersion @@ -293,6 +301,8 @@ func (cli *DockerCli) Initialize(opts *cliflags.ClientOptions, ops ...CLIOption) cli.createGlobalTracerProvider(cli.baseCtx) } + cli.oauthManager = manager.NewManager() + return nil } diff --git a/cli/command/registry/login.go b/cli/command/registry/login.go index 9f33a678fbd8..068777604bab 100644 --- a/cli/command/registry/login.go +++ b/cli/command/registry/login.go @@ -121,9 +121,39 @@ func runLogin(ctx context.Context, dockerCli command.Cli, opts loginOptions) err response, err = loginWithCredStoreCreds(ctx, dockerCli, &authConfig) } if err != nil || authConfig.Username == "" || authConfig.Password == "" { - err = command.ConfigureAuth(ctx, dockerCli, opts.user, opts.password, &authConfig, isDefaultRegistry) - if err != nil { - return err + if isDefaultRegistry && opts.user == "" && opts.password == "" { + // todo(laurazard: clean this up + tokenRes, err := dockerCli.OAuthManager().LoginDevice(ctx, dockerCli.Err()) + if err != nil { + return err + } + authConfig.Username = tokenRes.Claims.Domain.Username + authConfig.Password = tokenRes.AccessToken + authConfig.Email = tokenRes.Claims.Domain.Email + authConfig.ServerAddress = serverAddress + + response, err = clnt.RegistryLogin(ctx, authConfig) + if err != nil && client.IsErrConnectionFailed(err) { + // If the server isn't responding (yet) attempt to login purely client side + response, err = loginClientSide(ctx, authConfig) + } + // If we (still) have an error, give up + if err != nil { + return err + } + + authConfig.Password = authConfig.Password + ".." + tokenRes.RefreshToken + + creds := dockerCli.ConfigFile().GetCredentialsStore(serverAddress) + if err := creds.Store(configtypes.AuthConfig(authConfig)); err != nil { + return errors.Errorf("Error saving credentials: %v", err) + } + return nil + } else { + err = command.ConfigureAuth(ctx, dockerCli, opts.user, opts.password, &authConfig, isDefaultRegistry) + if err != nil { + return err + } } response, err = clnt.RegistryLogin(ctx, authConfig) diff --git a/cli/config/configfile/file.go b/cli/config/configfile/file.go index eba8a63956c8..8f2e0b088cfb 100644 --- a/cli/config/configfile/file.go +++ b/cli/config/configfile/file.go @@ -10,6 +10,7 @@ import ( "github.com/docker/cli/cli/config/credentials" "github.com/docker/cli/cli/config/types" + "github.com/docker/cli/cli/internal/oauth/manager" "github.com/pkg/errors" "github.com/sirupsen/logrus" ) @@ -260,7 +261,7 @@ func (configFile *ConfigFile) GetCredentialsStore(registryHostname string) crede } else { credsStore = credentials.NewFileStore(configFile) } - return credentials.NewOAuthStore(credsStore) + return credentials.NewOAuthStore(credsStore, manager.NewManager()) } // var for unit testing. diff --git a/cli/config/credentials/oauth_store.go b/cli/config/credentials/oauth_store.go index cd87ce4e9a0b..e6d7d24c7c4c 100644 --- a/cli/config/credentials/oauth_store.go +++ b/cli/config/credentials/oauth_store.go @@ -3,11 +3,9 @@ package credentials import ( "context" "errors" - "os" "strings" "time" - "github.com/docker/cli/cli/config/credentials/internal/oauth/manager" "github.com/docker/cli/cli/config/types" "github.com/docker/cli/cli/oauth" "github.com/docker/docker/registry" @@ -22,11 +20,10 @@ type oauthStore struct { } // NewOAuthStore creates a new oauthStore backed by the provided store. -func NewOAuthStore(backingStore Store) Store { - m, _ := manager.NewManager() +func NewOAuthStore(backingStore Store, manager oauth.Manager) Store { return &oauthStore{ backingStore: backingStore, - manager: m, + manager: manager, } } @@ -51,31 +48,20 @@ func (c *oauthStore) Get(serverAddress string) (types.AuthConfig, error) { // store itself. This should be propagated up. return types.AuthConfig{}, err } + tokenRes, err := c.parseToken(auth.Password) - if err != nil && auth.Password != "" { - return types.AuthConfig{ - Username: auth.Username, - Password: auth.Password, - Email: auth.Email, - ServerAddress: registry.IndexServer, - }, nil + // if the password is not a token, return the auth config as is + if err != nil { + return auth, nil } - var failedRefresh bool // if the access token is valid for less than 50 minutes, refresh it if tokenRes.RefreshToken != "" && tokenRes.Claims.Expiry.Time().Before(time.Now().Add(minimumTokenLifetime)) { refreshRes, err := c.manager.RefreshToken(context.TODO(), tokenRes.RefreshToken) - if err != nil { - failedRefresh = true - } - tokenRes = refreshRes - } - - if tokenRes.AccessToken == "" || failedRefresh { - tokenRes, err = c.manager.LoginDevice(context.TODO(), os.Stderr) if err != nil { return types.AuthConfig{}, err } + tokenRes = refreshRes } err = c.storeInBackingStore(tokenRes) @@ -121,19 +107,9 @@ func (c *oauthStore) Erase(serverAddress string) error { return c.backingStore.Erase(serverAddress) } -// Store stores the provided credentials in the backing credential store, -// except when the credentials are for the official registry, in which case -// no action is taken because the credentials retrieved/stored during Get. +// Store stores the provided credentials in the backing store, without any +// additional processing. func (c *oauthStore) Store(auth types.AuthConfig) error { - if auth.ServerAddress != registry.IndexServer { - return c.backingStore.Store(auth) - } - - _, err := c.parseToken(auth.Password) - if err == nil { - return nil - } - return c.backingStore.Store(auth) } diff --git a/cli/config/credentials/oauth_store_test.go b/cli/config/credentials/oauth_store_test.go index ae7dd497a6c2..19287f8edc9b 100644 --- a/cli/config/credentials/oauth_store_test.go +++ b/cli/config/credentials/oauth_store_test.go @@ -48,43 +48,18 @@ func TestOAuthStoreGet(t *testing.T) { }) }) - t.Run("no credentials - login", func(t *testing.T) { + t.Run("no credentials - return", func(t *testing.T) { auths := map[string]types.AuthConfig{} f := newStore(auths) - manager := &testManager{ - loginDevice: func() (oauth.TokenResult, error) { - return oauth.TokenResult{ - AccessToken: "abcd1234", - RefreshToken: "efgh5678", - Claims: oauth.Claims{ - Claims: jwt.Claims{ - Expiry: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), - }, - Domain: oauth.DomainClaims{Username: "bork!", Email: "bork@docker.com"}, - }, - }, nil - }, - } s := &oauthStore{ backingStore: NewFileStore(f), - manager: manager, } auth, err := s.Get(registry.IndexServer) assert.NilError(t, err) - assert.DeepEqual(t, auth, types.AuthConfig{ - Username: "bork!", - Password: "abcd1234", - Email: "bork@docker.com", - ServerAddress: registry.IndexServer, - }) - assert.DeepEqual(t, auths[registry.IndexServer], types.AuthConfig{ - Username: "bork!", - Password: "abcd1234..efgh5678", - Email: "bork@docker.com", - ServerAddress: registry.IndexServer, - }) + assert.DeepEqual(t, auth, types.AuthConfig{}) + assert.Equal(t, len(auths), 0) }) t.Run("expired credentials - refresh", func(t *testing.T) { @@ -135,7 +110,7 @@ func TestOAuthStoreGet(t *testing.T) { }) }) - t.Run("expired credentials - refresh fails - login", func(t *testing.T) { + t.Run("expired credentials - refresh fails - return error", func(t *testing.T) { f := newStore(map[string]types.AuthConfig{ registry.IndexServer: { Username: "bork!", @@ -144,23 +119,11 @@ func TestOAuthStoreGet(t *testing.T) { ServerAddress: registry.IndexServer, }, }) - var loginCalled bool + var refreshCalled bool manager := &testManager{ refresh: func(_ string) (oauth.TokenResult, error) { - return oauth.TokenResult{}, errors.New("program failed") - }, - loginDevice: func() (oauth.TokenResult, error) { - loginCalled = true - return oauth.TokenResult{ - AccessToken: "abcd1234", - RefreshToken: "efgh5678", - Claims: oauth.Claims{ - Claims: jwt.Claims{ - Expiry: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)), - }, - Domain: oauth.DomainClaims{Username: "bork!", Email: "bork@docker.com"}, - }, - }, nil + refreshCalled = true + return oauth.TokenResult{}, errors.New("refresh failed") }, } s := &oauthStore{ @@ -168,22 +131,10 @@ func TestOAuthStoreGet(t *testing.T) { manager: manager, } - auth, err := s.Get(registry.IndexServer) - assert.NilError(t, err) + _, err := s.Get(registry.IndexServer) + assert.ErrorContains(t, err, "refresh failed") - assert.Check(t, loginCalled) - assert.DeepEqual(t, auth, types.AuthConfig{ - Username: "bork!", - Password: "abcd1234", - Email: "bork@docker.com", - ServerAddress: registry.IndexServer, - }) - assert.DeepEqual(t, f.GetAuthConfigs()[registry.IndexServer], types.AuthConfig{ - Username: "bork!", - Password: "abcd1234..efgh5678", - Email: "bork@docker.com", - ServerAddress: registry.IndexServer, - }) + assert.Check(t, refreshCalled) }) t.Run("old non-access token credentials", func(t *testing.T) { @@ -386,6 +337,7 @@ func TestStore(t *testing.T) { } err := s.Store(auth) assert.NilError(t, err) + assert.Check(t, is.Len(f.GetAuthConfigs(), 1)) }) @@ -402,7 +354,14 @@ func TestStore(t *testing.T) { } err := s.Store(auth) assert.NilError(t, err) - assert.Check(t, is.Len(f.GetAuthConfigs(), 0)) + + assert.Check(t, is.Len(f.GetAuthConfigs(), 1)) + assert.DeepEqual(t, f.GetAuthConfigs()[registry.IndexServer], types.AuthConfig{ + Username: "foo", + Password: validNotExpiredToken + "..refresh-token", + Email: "foo@example.com", + ServerAddress: registry.IndexServer, + }) }) }) diff --git a/cli/internal/oauth/api/api.go b/cli/internal/oauth/api/api.go new file mode 100644 index 000000000000..1f3f1430067c --- /dev/null +++ b/cli/internal/oauth/api/api.go @@ -0,0 +1,153 @@ +package api + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + "strings" + "time" + + "github.com/docker/cli/cli/internal/oauth/util" +) + +type OAuthAPI interface { + GetDeviceCode(ctx context.Context, audience string) (State, error) + WaitForDeviceToken(ctx context.Context, state State) (TokenResponse, error) + Refresh(ctx context.Context, token string) (TokenResponse, error) + LogoutURL() string +} + +// API represents API interactions with Auth0. +type API struct { + // BaseURL is the base used for each request to Auth0. + BaseURL string + // ClientID is the client ID for the application to auth with the tenant. + ClientID string + // Scopes are the scopes that are requested during the device auth flow. + Scopes []string + // Client is the client that is used for calls. + Client util.Client +} + +// TokenResponse represents the response of the /oauth/token route. +type TokenResponse struct { + AccessToken string `json:"access_token"` + IDToken string `json:"id_token"` + RefreshToken string `json:"refresh_token"` + Scope string `json:"scope"` + ExpiresIn int `json:"expires_in"` + TokenType string `json:"token_type"` + Error *string `json:"error,omitempty"` + ErrorDescription string `json:"error_description,omitempty"` +} + +var ErrTimeout = errors.New("timed out waiting for device token") + +// GetDeviceCode returns device code authorization information from Auth0. +func (a API) GetDeviceCode(ctx context.Context, audience string) (state State, err error) { + data := url.Values{ + "client_id": {a.ClientID}, + "audience": {audience}, + "scope": {strings.Join(a.Scopes, " ")}, + } + + deviceCodeURL := a.BaseURL + "/oauth/device/code" + resp, err := a.Client.PostForm(deviceCodeURL, strings.NewReader(data.Encode())) + if err != nil { + return + } + defer func() { + _ = resp.Body.Close() + }() + + if resp.StatusCode != http.StatusOK { + var body map[string]any + err = json.NewDecoder(resp.Body).Decode(&body) + if errorDescription, ok := body["error_description"].(string); ok { + return state, errors.New(errorDescription) + } + return state, fmt.Errorf("failed to get device code: %w", err) + } + + err = json.NewDecoder(resp.Body).Decode(&state) + + return +} + +// WaitForDeviceToken polls to get tokens based on the device code set up. This +// only works in a device auth flow. +func (a API) WaitForDeviceToken(ctx context.Context, state State) (TokenResponse, error) { + ticker := time.NewTicker(state.IntervalDuration()) + timeout := time.After(time.Duration(state.ExpiresIn) * time.Second) + + for { + select { + case <-ticker.C: + res, err := a.getDeviceToken(state) + if err != nil { + return res, err + } + + if res.Error != nil { + if *res.Error == "authorization_pending" { + continue + } + + return res, errors.New(res.ErrorDescription) + } + + return res, nil + case <-timeout: + ticker.Stop() + return TokenResponse{}, ErrTimeout + } + } +} + +// getToken calls the token endpoint of Auth0 and returns the response. +func (a API) getDeviceToken(state State) (res TokenResponse, err error) { + data := url.Values{ + "client_id": {a.ClientID}, + "grant_type": {"urn:ietf:params:oauth:grant-type:device_code"}, + "device_code": {state.DeviceCode}, + } + oauthTokenURL := a.BaseURL + "/oauth/token" + + resp, err := a.Client.PostForm(oauthTokenURL, strings.NewReader(data.Encode())) + if err != nil { + return res, fmt.Errorf("failed to get code: %w", err) + } + + err = json.NewDecoder(resp.Body).Decode(&res) + _ = resp.Body.Close() + + return +} + +// Refresh returns new tokens based on the refresh token. +func (a API) Refresh(ctx context.Context, token string) (res TokenResponse, err error) { + data := url.Values{ + "grant_type": {"refresh_token"}, + "client_id": {a.ClientID}, + "refresh_token": {token}, + } + + refreshURL := a.BaseURL + "/oauth/token" + //nolint:gosec // Ignore G107: Potential HTTP request made with variable url + resp, err := http.PostForm(refreshURL, data) + if err != nil { + return + } + + err = json.NewDecoder(resp.Body).Decode(&res) + _ = resp.Body.Close() + + return +} + +func (a API) LogoutURL() string { + return fmt.Sprintf("%s/v2/logout?client_id=%s", a.BaseURL, a.ClientID) +} diff --git a/cli/internal/oauth/api/api_test.go b/cli/internal/oauth/api/api_test.go new file mode 100644 index 000000000000..78142afafc0d --- /dev/null +++ b/cli/internal/oauth/api/api_test.go @@ -0,0 +1,188 @@ +package api + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/docker/cli/cli/internal/oauth/util" + "gotest.tools/v3/assert" +) + +func TestGetDeviceCode(t *testing.T) { + t.Run("success", func(t *testing.T) { + var clientID, audience, scope, path string + expectedState := State{ + DeviceCode: "aDeviceCode", + UserCode: "aUserCode", + VerificationURI: "aVerificationURI", + ExpiresIn: 60, + } + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + clientID = r.FormValue("client_id") + audience = r.FormValue("audience") + scope = r.FormValue("scope") + path = r.URL.Path + + jsonState, err := json.Marshal(expectedState) + assert.NilError(t, err) + + _, _ = w.Write(jsonState) + })) + defer ts.Close() + api := API{ + BaseURL: ts.URL, + ClientID: "aClientID", + Scopes: []string{"bork", "meow"}, + Client: util.Client{}, + } + + state, err := api.GetDeviceCode(context.Background(), "anAudience") + assert.NilError(t, err) + + assert.DeepEqual(t, expectedState, state) + assert.Equal(t, clientID, "aClientID") + assert.Equal(t, audience, "anAudience") + assert.Equal(t, scope, "bork meow") + assert.Equal(t, path, "/oauth/device/code") + }) + + t.Run("error w/ description", func(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + jsonState, err := json.Marshal(TokenResponse{ + ErrorDescription: "invalid audience", + }) + assert.NilError(t, err) + + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write(jsonState) + })) + defer ts.Close() + api := API{ + BaseURL: ts.URL, + ClientID: "aClientID", + Scopes: []string{"bork", "meow"}, + Client: util.Client{}, + } + + _, err := api.GetDeviceCode(context.Background(), "bad_audience") + + assert.ErrorContains(t, err, "invalid audience") + }) + + t.Run("general error", func(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "an error", http.StatusInternalServerError) + })) + defer ts.Close() + api := API{ + BaseURL: ts.URL, + ClientID: "aClientID", + Scopes: []string{"bork", "meow"}, + Client: util.Client{}, + } + + _, err := api.GetDeviceCode(context.Background(), "anAudience") + + assert.ErrorContains(t, err, "failed to get device code") + }) + + // todo(laurazard): test + t.Run("canceled context", func(t *testing.T) {}) +} + +func TestWaitForDeviceToken(t *testing.T) { + t.Run("success", func(t *testing.T) { + expectedToken := TokenResponse{ + AccessToken: "a-real-token", + IDToken: "", + RefreshToken: "the-refresh-token", + Scope: "", + ExpiresIn: 3600, + TokenType: "", + } + var respond bool + go func() { + time.Sleep(10 * time.Second) + respond = true + }() + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "POST", r.Method) + assert.Equal(t, "/oauth/token", r.URL.Path) + assert.Equal(t, r.FormValue("client_id"), "aClientID") + assert.Equal(t, r.FormValue("grant_type"), "urn:ietf:params:oauth:grant-type:device_code") + assert.Equal(t, r.FormValue("device_code"), "aDeviceCode") + + if respond { + jsonState, err := json.Marshal(expectedToken) + assert.NilError(t, err) + w.Write(jsonState) + } else { + pendingError := "authorization_pending" + jsonResponse, err := json.Marshal(TokenResponse{ + Error: &pendingError, + }) + assert.NilError(t, err) + w.Write(jsonResponse) + } + })) + defer ts.Close() + api := API{ + BaseURL: ts.URL, + ClientID: "aClientID", + Scopes: []string{"bork", "meow"}, + Client: util.Client{}, + } + state := State{ + DeviceCode: "aDeviceCode", + UserCode: "aUserCode", + Interval: 1, + ExpiresIn: 30, + } + token, err := api.WaitForDeviceToken(context.Background(), state) + assert.NilError(t, err) + + assert.DeepEqual(t, token, expectedToken) + }) + + t.Run("timeout", func(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "POST", r.Method) + assert.Equal(t, "/oauth/token", r.URL.Path) + assert.Equal(t, r.FormValue("client_id"), "aClientID") + assert.Equal(t, r.FormValue("grant_type"), "urn:ietf:params:oauth:grant-type:device_code") + assert.Equal(t, r.FormValue("device_code"), "aDeviceCode") + + pendingError := "authorization_pending" + jsonResponse, err := json.Marshal(TokenResponse{ + Error: &pendingError, + }) + assert.NilError(t, err) + w.Write(jsonResponse) + })) + defer ts.Close() + api := API{ + BaseURL: ts.URL, + ClientID: "aClientID", + Scopes: []string{"bork", "meow"}, + Client: util.Client{}, + } + state := State{ + DeviceCode: "aDeviceCode", + UserCode: "aUserCode", + Interval: 1, + ExpiresIn: 1, + } + + _, err := api.WaitForDeviceToken(context.Background(), state) + + assert.ErrorIs(t, err, ErrTimeout) + }) + + // todo(laurazard): test + t.Run("canceled context", func(t *testing.T) {}) +} diff --git a/cli/internal/oauth/api/state.go b/cli/internal/oauth/api/state.go new file mode 100644 index 000000000000..e6f5397d688a --- /dev/null +++ b/cli/internal/oauth/api/state.go @@ -0,0 +1,20 @@ +package api + +import ( + "time" +) + +// State represents the state of exchange after submitting. +type State struct { + DeviceCode string `json:"device_code"` + UserCode string `json:"user_code"` + VerificationURI string `json:"verification_uri_complete"` + ExpiresIn int `json:"expires_in"` + Interval int `json:"interval"` +} + +// IntervalDuration returns the duration that should be waited between each auth +// polling event. +func (s State) IntervalDuration() time.Duration { + return time.Second * time.Duration(s.Interval) +} diff --git a/cli/internal/oauth/manager/manager.go b/cli/internal/oauth/manager/manager.go new file mode 100644 index 000000000000..689bb8d87a4c --- /dev/null +++ b/cli/internal/oauth/manager/manager.go @@ -0,0 +1,162 @@ +package manager + +import ( + "bufio" + "context" + "errors" + "fmt" + "io" + "os" + "strings" + + "github.com/docker/cli/cli/internal/oauth/api" + "github.com/docker/cli/cli/internal/oauth/util" + "github.com/docker/cli/cli/oauth" +) + +// OAuthManager is the manager +type OAuthManager struct { + api api.OAuthAPI + audience string + tenant string + openBrowser func(string) error +} + +// OAuthManagerOptions is the options used for New to create a new auth manager. +type OAuthManagerOptions struct { + Audience string + ClientID string + Scopes []string + Tenant string + DeviceName string + OpenBrowser func(string) error +} + +func New(options OAuthManagerOptions) *OAuthManager { + scopes := []string{"openid", "offline_access"} + if len(options.Scopes) > 0 { + scopes = options.Scopes + } + + openBrowser := util.OpenBrowser + if options.OpenBrowser != nil { + openBrowser = options.OpenBrowser + } + + return &OAuthManager{ + audience: options.Audience, + api: api.API{ + BaseURL: "https://" + options.Tenant, + ClientID: options.ClientID, + Scopes: scopes, + Client: util.Client{ + UserAgent: options.DeviceName, + }, + }, + tenant: options.Tenant, + openBrowser: openBrowser, + } +} + +// LoginDevice launches the device authentication flow with the tenant, printing instructions +// to the provided writer and attempting to open the browser for the user to authenticate. +// Once complete, the retrieved tokens are stored and returned. +func (m *OAuthManager) LoginDevice(ctx context.Context, w io.Writer) (res oauth.TokenResult, err error) { + state, err := m.api.GetDeviceCode(ctx, m.audience) + if err != nil { + return res, fmt.Errorf("login failed: %w", err) + } + + if state.UserCode == "" { + return res, errors.New("login failed: no user code returned") + } + + _, _ = fmt.Fprintln(w, "\nYou will be signed in using a web-based login.") + _, _ = fmt.Fprintln(w, "To sign in with credentials on the command line, use 'docker login -u '") + _, _ = fmt.Fprintf(w, "\nYour one-time device confirmation code is: %s\n", state.UserCode) + _, _ = fmt.Fprint(w, "\nPress ENTER to open the browser.\n") + _, _ = fmt.Fprintf(w, "Or open the URL manually: %s.\n", strings.Split(state.VerificationURI, "?")[0]) + + tokenResChan := make(chan api.TokenResponse) + waitForTokenErrChan := make(chan error) + go func() { + tokenRes, err := m.api.WaitForDeviceToken(ctx, state) + if err != nil { + waitForTokenErrChan <- err + return + } + tokenResChan <- tokenRes + }() + + go func() { + reader := bufio.NewReader(os.Stdin) + reader.ReadString('\n') + _ = m.openBrowser(state.VerificationURI) + }() + + _, _ = fmt.Fprint(w, "\nWaiting for authentication in the browser...\n") + var tokenRes api.TokenResponse + select { + case <-ctx.Done(): + return res, errors.New("login canceled") + case err := <-waitForTokenErrChan: + return res, fmt.Errorf("login failed: %w", err) + case tokenRes = <-tokenResChan: + } + + claims, err := oauth.GetClaims(tokenRes.AccessToken) + if err != nil { + return res, fmt.Errorf("login failed: %w", err) + } + + res.Tenant = m.tenant + res.AccessToken = tokenRes.AccessToken + res.RefreshToken = tokenRes.RefreshToken + res.Claims = claims + + return res, nil +} + +// Logout logs out of the session for the client and removes tokens from the storage provider. +func (m *OAuthManager) Logout(ctx context.Context) error { + return errors.Join( + m.openBrowser(m.api.LogoutURL()), + ) +} + +var ( + // ErrNoCreds is returned by RefreshToken when the store does not contain credentials + // for the official registry. + ErrNoCreds = errors.New("no credentials found") + + // ErrUnexpiredToken is returned by RefreshToken when the token is not expired. + ErrUnexpiredToken = errors.New("token is not expired") +) + +// RefreshToken fetches credentials from the store, refreshes them, stores the new tokens +// and returns them. +// If there are no credentials in the store, ErrNoCreds is returned. +func (m OAuthManager) RefreshToken(ctx context.Context, refreshToken string) (res oauth.TokenResult, err error) { + refreshRes, err := m.api.Refresh(ctx, refreshToken) + if err != nil { + return res, err + } + + // todo(laurazard) + // select { + // case <-ctx.Done(): + // return "", ctx.Err() + // default: + // } + + claims, err := oauth.GetClaims(refreshRes.AccessToken) + if err != nil { + return res, err + } + + res.Tenant = m.tenant + res.AccessToken = refreshRes.AccessToken + res.RefreshToken = refreshRes.RefreshToken + res.Claims = claims + return res, nil +} diff --git a/cli/internal/oauth/manager/manager_test.go b/cli/internal/oauth/manager/manager_test.go new file mode 100644 index 000000000000..2243ae9fde43 --- /dev/null +++ b/cli/internal/oauth/manager/manager_test.go @@ -0,0 +1,229 @@ +package manager + +import ( + "context" + "os" + "testing" + + "github.com/docker/cli/cli/internal/oauth/api" + "github.com/docker/cli/cli/oauth" + "github.com/go-jose/go-jose/v3/jwt" + "gotest.tools/v3/assert" +) + +const ( + //nolint:lll + validToken = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCIsImtpZCI6InhYa3BCdDNyV3MyRy11YjlscEpncSJ9.eyJodHRwczovL2h1Yi5kb2NrZXIuY29tIjp7ImVtYWlsIjoiYm9ya0Bkb2NrZXIuY29tIiwic2Vzc2lvbl9pZCI6ImEtc2Vzc2lvbi1pZCIsInNvdXJjZSI6InNhbWxwIiwidXNlcm5hbWUiOiJib3JrISIsInV1aWQiOiIwMTIzLTQ1Njc4OSJ9LCJpc3MiOiJodHRwczovL2xvZ2luLmRvY2tlci5jb20vIiwic3ViIjoic2FtbHB8c2FtbHAtZG9ja2VyfGJvcmtAZG9ja2VyLmNvbSIsImF1ZCI6WyJodHRwczovL2F1ZGllbmNlLmNvbSJdLCJpYXQiOjE3MTk1MDI5MzksImV4cCI6MTcxOTUwNjUzOSwic2NvcGUiOiJvcGVuaWQgb2ZmbGluZV9hY2Nlc3MifQ.VUSp-9_SOvMPWJPRrSh7p4kSPoye4DA3kyd2I0TW0QtxYSRq7xCzNj0NC_ywlPlKBFBeXKm4mh93d1vBSh79I9Heq5tj0Fr4KH77U5xJRMEpjHqoT5jxMEU1hYXX92xctnagBMXxDvzUfu3Yf0tvYSA0RRoGbGTHfdYYRwOrGbwQ75Qg1dyIxUkwsG053eYX2XkmLGxymEMgIq_gWksgAamOc40_0OCdGr-MmDeD2HyGUa309aGltzQUw7Z0zG1AKSXy3WwfMHdWNFioTAvQphwEyY3US8ybSJi78upSFTjwUcryMeHUwQ3uV9PxwPMyPoYxo1izVB-OUJxM8RqEbg" + //nolint:lll + newerToken = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCIsImtpZCI6InhYa3BCdDNyV3MyRy11YjlscEpncSJ9.eyJodHRwczovL2h1Yi5kb2NrZXIuY29tIjp7ImVtYWlsIjoiYm9ya0Bkb2NrZXIuY29tIiwic2Vzc2lvbl9pZCI6ImEtc2Vzc2lvbi1pZCIsInNvdXJjZSI6InNhbWxwIiwidXNlcm5hbWUiOiJib3JrISIsInV1aWQiOiIwMTIzLTQ1Njc4OSJ9LCJpc3MiOiJodHRwczovL2xvZ2luLmRvY2tlci5jb20vIiwic3ViIjoic2FtbHB8c2FtbHAtZG9ja2VyfGJvcmtAZG9ja2VyLmNvbSIsImF1ZCI6WyJodHRwczovL2F1ZGllbmNlLmNvbSJdLCJpYXQiOjI3MTk1MDI5MzksImV4cCI6MjcxOTUwNjUzOSwic2NvcGUiOiJvcGVuaWQgb2ZmbGluZV9hY2Nlc3MifQ.VUSp-9_SOvMPWJPRrSh7p4kSPoye4DA3kyd2I0TW0QtxYSRq7xCzNj0NC_ywlPlKBFBeXKm4mh93d1vBSh79I9Heq5tj0Fr4KH77U5xJRMEpjHqoT5jxMEU1hYXX92xctnagBMXxDvzUfu3Yf0tvYSA0RRoGbGTHfdYYRwOrGbwQ75Qg1dyIxUkwsG053eYX2XkmLGxymEMgIq_gWksgAamOc40_0OCdGr-MmDeD2HyGUa309aGltzQUw7Z0zG1AKSXy3WwfMHdWNFioTAvQphwEyY3US8ybSJi78upSFTjwUcryMeHUwQ3uV9PxwPMyPoYxo1izVB-OUJxM8RqEbg" +) + +var ( + expiry = jwt.NumericDate(1719506539) + issuedAt = jwt.NumericDate(1719502939) + validParsedToken = oauth.TokenResult{ + AccessToken: validToken, + RefreshToken: "refresh-token", + Claims: oauth.Claims{ + Claims: jwt.Claims{ + Issuer: "https://login.docker.com/", + Subject: "samlp|samlp-docker|bork@docker.com", + Audience: jwt.Audience{ + "https://audience.com", + }, + Expiry: &expiry, + IssuedAt: &issuedAt, + }, + Domain: oauth.DomainClaims{ + UUID: "0123-456789", + Email: "bork@docker.com", + Username: "bork!", + Source: "samlp", + SessionID: "a-session-id", + }, + Scope: "openid offline_access", + }, + } +) + +func TestLoginDevice(t *testing.T) { + t.Run("valid token", func(t *testing.T) { + expectedState := api.State{ + DeviceCode: "device-code", + UserCode: "0123-4567", + VerificationURI: "an-url", + ExpiresIn: 300, + } + var receivedAudience string + getDeviceToken := func(audience string) (api.State, error) { + receivedAudience = audience + return expectedState, nil + } + var receivedState api.State + waitForDeviceToken := func(state api.State) (api.TokenResponse, error) { + receivedState = state + return api.TokenResponse{ + AccessToken: validToken, + RefreshToken: "refresh-token", + }, nil + } + api := &testAPI{ + getDeviceToken: getDeviceToken, + waitForDeviceToken: waitForDeviceToken, + } + manager := OAuthManager{ + audience: "https://hub.docker.com", + api: api, + openBrowser: func(url string) error { + return nil + }, + } + + res, err := manager.LoginDevice(context.Background(), os.Stderr) + assert.NilError(t, err) + + assert.Equal(t, receivedAudience, "https://hub.docker.com") + assert.Equal(t, receivedState, expectedState) + assert.DeepEqual(t, res, validParsedToken) + }) + + t.Run("timeout", func(t *testing.T) { + getDeviceToken := func(audience string) (api.State, error) { + return api.State{ + DeviceCode: "device-code", + UserCode: "0123-4567", + VerificationURI: "an-url", + ExpiresIn: 300, + }, nil + } + waitForDeviceToken := func(state api.State) (api.TokenResponse, error) { + return api.TokenResponse{}, api.ErrTimeout + } + a := &testAPI{ + getDeviceToken: getDeviceToken, + waitForDeviceToken: waitForDeviceToken, + } + manager := OAuthManager{ + api: a, + openBrowser: func(url string) error { + return nil + }, + } + + _, err := manager.LoginDevice(context.Background(), os.Stderr) + assert.ErrorContains(t, err, "login failed: timed out waiting for device token") + }) + + // todo(laurazard): test the case where the user cancels the login + t.Run("canceled context", func(t *testing.T) {}) + + t.Run("stores in cred store", func(t *testing.T) { + getDeviceToken := func(audience string) (api.State, error) { + return api.State{ + DeviceCode: "device-code", + UserCode: "0123-4567", + }, nil + } + waitForDeviceToken := func(state api.State) (api.TokenResponse, error) { + return api.TokenResponse{ + AccessToken: validToken, + RefreshToken: "refresh-token", + }, nil + } + a := &testAPI{ + getDeviceToken: getDeviceToken, + waitForDeviceToken: waitForDeviceToken, + } + manager := OAuthManager{ + api: a, + openBrowser: func(url string) error { + return nil + }, + } + + res, err := manager.LoginDevice(context.Background(), os.Stderr) + assert.NilError(t, err) + + assert.Equal(t, res.AccessToken, validToken) + }) +} + +func TestLogout(t *testing.T) { + a := &testAPI{ + logoutURL: "test-logout-url", + } + var browserOpenURL string + manager := OAuthManager{ + api: a, + openBrowser: func(url string) error { + browserOpenURL = url + return nil + }, + } + + err := manager.Logout(context.Background()) + assert.NilError(t, err) + + assert.Equal(t, browserOpenURL, "test-logout-url") +} + +func TestRefreshToken(t *testing.T) { + t.Run("success", func(t *testing.T) { + var receivedRefreshToken string + a := &testAPI{ + refresh: func(token string) (api.TokenResponse, error) { + receivedRefreshToken = token + return api.TokenResponse{ + AccessToken: newerToken, + RefreshToken: "new-refresh-token", + }, nil + }, + } + manager := OAuthManager{ + api: a, + } + + res, err := manager.RefreshToken(context.Background(), "old-refresh-token") + assert.NilError(t, err) + + assert.Equal(t, receivedRefreshToken, "old-refresh-token") + assert.Equal(t, res.AccessToken, newerToken) + }) + + // todo(laurazard): test the case where the user cancels the refresh + t.Run("canceled context", func(t *testing.T) {}) +} + +var _ api.OAuthAPI = &testAPI{} + +type testAPI struct { + logoutURL string + getDeviceToken func(audience string) (api.State, error) + waitForDeviceToken func(state api.State) (api.TokenResponse, error) + refresh func(token string) (api.TokenResponse, error) +} + +func (t *testAPI) GetDeviceCode(_ context.Context, audience string) (api.State, error) { + if t.getDeviceToken != nil { + return t.getDeviceToken(audience) + } + return api.State{}, nil +} + +func (t *testAPI) WaitForDeviceToken(_ context.Context, state api.State) (api.TokenResponse, error) { + if t.waitForDeviceToken != nil { + return t.waitForDeviceToken(state) + } + return api.TokenResponse{}, nil +} + +func (t *testAPI) Refresh(_ context.Context, token string) (api.TokenResponse, error) { + if t.refresh != nil { + return t.refresh(token) + } + return api.TokenResponse{}, nil +} + +func (t *testAPI) LogoutURL() string { + return t.logoutURL +} diff --git a/cli/internal/oauth/manager/util.go b/cli/internal/oauth/manager/util.go new file mode 100644 index 000000000000..27629c8b29ab --- /dev/null +++ b/cli/internal/oauth/manager/util.go @@ -0,0 +1,26 @@ +package manager + +import ( + "fmt" + "runtime" + "strings" + + "github.com/docker/cli/cli/version" +) + +const ( + audience = "https://hub.docker.com" + tenant = "login.docker.com" + clientID = "DHWuMefQ1v4lxENpz8oUYH50yYSwyPvi" +) + +func NewManager() *OAuthManager { + cliVersion := strings.ReplaceAll(version.Version, ".", "_") + options := OAuthManagerOptions{ + Audience: audience, + ClientID: clientID, + Tenant: tenant, + DeviceName: fmt.Sprintf("docker-cli:%s:%s-%s", cliVersion, runtime.GOOS, runtime.GOARCH), + } + return New(options) +} diff --git a/cli/internal/oauth/util/client.go b/cli/internal/oauth/util/client.go new file mode 100644 index 000000000000..3903035c6977 --- /dev/null +++ b/cli/internal/oauth/util/client.go @@ -0,0 +1,50 @@ +package util + +import ( + "io" + "net/http" +) + +// Client is a client and actions for interacting with the tenant auth API. +type Client struct { + UserAgent string +} + +// setHeaders sets common headers for requests. +func (c Client) setHeaders(req *http.Request, isForm bool) { + if isForm { + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + } + + if c.UserAgent != "" { + req.Header.Set("User-Agent", c.UserAgent) + } +} + +// PostForm does a POST request with form data. +func (c Client) PostForm(url string, data io.Reader) (*http.Response, error) { + client := http.Client{} + + req, err := http.NewRequest(http.MethodPost, url, data) + if err != nil { + return nil, err + } + + c.setHeaders(req, true) + + return client.Do(req) +} + +// Post does a POST with the specified data. +func (c Client) Post(url string, data io.Reader) (*http.Response, error) { + client := http.Client{} + + req, err := http.NewRequest(http.MethodPost, url, data) + if err != nil { + return nil, err + } + + c.setHeaders(req, false) + + return client.Do(req) +} diff --git a/cli/internal/oauth/util/util.go b/cli/internal/oauth/util/util.go new file mode 100644 index 000000000000..c170e77a4371 --- /dev/null +++ b/cli/internal/oauth/util/util.go @@ -0,0 +1,36 @@ +package util + +import ( + "errors" + "os/exec" + "runtime" + "time" + + "github.com/docker/cli/cli/oauth" + "github.com/go-jose/go-jose/v3/jwt" +) + +// IsExpired returns whether the claims are expired or not. +func IsExpired(claims oauth.Claims) bool { + err := claims.Validate(jwt.Expected{ + Time: time.Now().UTC(), + }) + + return errors.Is(err, jwt.ErrExpired) +} + +// OpenBrowser opens the specified URL in a browser based on OS. +func OpenBrowser(url string) (err error) { + switch runtime.GOOS { + case "linux": + err = exec.Command("xdg-open", url).Start() + case "windows": + err = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start() + case "darwin": + err = exec.Command("open", url).Start() + default: + err = errors.New("unsupported platform") + } + + return +} diff --git a/internal/test/cli.go b/internal/test/cli.go index f84413ba811a..ebd3ffbb16df 100644 --- a/internal/test/cli.go +++ b/internal/test/cli.go @@ -2,6 +2,7 @@ package test import ( "bytes" + "context" "errors" "io" "strings" @@ -11,6 +12,7 @@ import ( "github.com/docker/cli/cli/context/docker" "github.com/docker/cli/cli/context/store" manifeststore "github.com/docker/cli/cli/manifest/store" + "github.com/docker/cli/cli/oauth" registryclient "github.com/docker/cli/cli/registry/client" "github.com/docker/cli/cli/streams" "github.com/docker/cli/cli/trust" @@ -211,3 +213,21 @@ func EnableContentTrust(c *FakeCli) { func (c *FakeCli) BuildKitEnabled() (bool, error) { return true, nil } + +func (c *FakeCli) OAuthManager() oauth.Manager { + return &fakeOauthManager{} +} + +type fakeOauthManager struct{} + +func (f *fakeOauthManager) LoginDevice(ctx context.Context, w io.Writer) (res oauth.TokenResult, err error) { + return res, nil +} + +func (f *fakeOauthManager) Logout(ctx context.Context) error { + return nil +} + +func (f *fakeOauthManager) RefreshToken(ctx context.Context, refreshToken string) (res oauth.TokenResult, err error) { + return res, nil +}