Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🚧 Add support for device-code flow login (alternative 2) 🚧 #5245

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cli/command/registry/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ func runLogin(ctx context.Context, dockerCli command.Cli, opts loginOptions) err
}

creds := dockerCli.ConfigFile().GetCredentialsStore(serverAddress)

// todo(laurazard): this will no longer trigger even when the store is a file store
store, isDefault := creds.(isFileStore)
// Display a warning if we're storing the users password (not a token)
if isDefault && authConfig.Password != "" {
Expand Down
4 changes: 3 additions & 1 deletion cli/command/registry/login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,9 @@ func TestLoginTermination(t *testing.T) {

runErr := make(chan error)
go func() {
runErr <- runLogin(ctx, cli, loginOptions{})
runErr <- runLogin(ctx, cli, loginOptions{
user: "test-user",
})
}()

// Let the prompt get canceled by the context
Expand Down
7 changes: 5 additions & 2 deletions cli/config/configfile/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,10 +254,13 @@ func decodeAuth(authStr string) (string, string, error) {
// GetCredentialsStore returns a new credentials store from the settings in the
// configuration file
func (configFile *ConfigFile) GetCredentialsStore(registryHostname string) credentials.Store {
var credsStore credentials.Store
if helper := getConfiguredCredentialStore(configFile, registryHostname); helper != "" {
return newNativeStore(configFile, helper)
credsStore = newNativeStore(configFile, helper)
} else {
credsStore = credentials.NewFileStore(configFile)
}
return credentials.NewFileStore(configFile)
return credentials.NewOAuthStore(credsStore)
}

// var for unit testing.
Expand Down
153 changes: 153 additions & 0 deletions cli/config/credentials/internal/oauth/api/api.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
package api

import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"time"

"github.com/docker/cli/cli/config/credentials/internal/oauth/util"
)

type OAuthAPI interface {
GetDeviceCode(ctx context.Context, audience string) (State, error)
WaitForDeviceToken(ctx context.Context, state State) (TokenResponse, error)
Refresh(ctx context.Context, token string) (TokenResponse, error)
LogoutURL() string
}

// API represents API interactions with Auth0.
type API struct {
// BaseURL is the base used for each request to Auth0.
BaseURL string
// ClientID is the client ID for the application to auth with the tenant.
ClientID string
// Scopes are the scopes that are requested during the device auth flow.
Scopes []string
// Client is the client that is used for calls.
Client util.Client
}

// TokenResponse represents the response of the /oauth/token route.
type TokenResponse struct {
AccessToken string `json:"access_token"`
IDToken string `json:"id_token"`
RefreshToken string `json:"refresh_token"`
Scope string `json:"scope"`
ExpiresIn int `json:"expires_in"`
TokenType string `json:"token_type"`
Error *string `json:"error,omitempty"`
ErrorDescription string `json:"error_description,omitempty"`
}

var ErrTimeout = errors.New("timed out waiting for device token")

// GetDeviceCode returns device code authorization information from Auth0.
func (a API) GetDeviceCode(ctx context.Context, audience string) (state State, err error) {
data := url.Values{
"client_id": {a.ClientID},
"audience": {audience},
"scope": {strings.Join(a.Scopes, " ")},
}

deviceCodeURL := a.BaseURL + "/oauth/device/code"
resp, err := a.Client.PostForm(deviceCodeURL, strings.NewReader(data.Encode()))
if err != nil {
return
}
defer func() {
_ = resp.Body.Close()
}()

if resp.StatusCode != http.StatusOK {
var body map[string]any
err = json.NewDecoder(resp.Body).Decode(&body)
if errorDescription, ok := body["error_description"].(string); ok {
return state, errors.New(errorDescription)
}
return state, fmt.Errorf("failed to get device code: %w", err)
}

err = json.NewDecoder(resp.Body).Decode(&state)

return
}

// WaitForDeviceToken polls to get tokens based on the device code set up. This
// only works in a device auth flow.
func (a API) WaitForDeviceToken(ctx context.Context, state State) (TokenResponse, error) {
ticker := time.NewTicker(state.IntervalDuration())
timeout := time.After(time.Duration(state.ExpiresIn) * time.Second)

for {
select {
case <-ticker.C:
res, err := a.getDeviceToken(state)
if err != nil {
return res, err
}

if res.Error != nil {
if *res.Error == "authorization_pending" {
continue
}

return res, errors.New(res.ErrorDescription)
}

return res, nil
case <-timeout:
ticker.Stop()
return TokenResponse{}, ErrTimeout
}
}
}

// getToken calls the token endpoint of Auth0 and returns the response.
func (a API) getDeviceToken(state State) (res TokenResponse, err error) {
data := url.Values{
"client_id": {a.ClientID},
"grant_type": {"urn:ietf:params:oauth:grant-type:device_code"},
"device_code": {state.DeviceCode},
}
oauthTokenURL := a.BaseURL + "/oauth/token"

resp, err := a.Client.PostForm(oauthTokenURL, strings.NewReader(data.Encode()))
if err != nil {
return res, fmt.Errorf("failed to get code: %w", err)
}

err = json.NewDecoder(resp.Body).Decode(&res)
_ = resp.Body.Close()

return
}

// Refresh returns new tokens based on the refresh token.
func (a API) Refresh(ctx context.Context, token string) (res TokenResponse, err error) {
data := url.Values{
"grant_type": {"refresh_token"},
"client_id": {a.ClientID},
"refresh_token": {token},
}

refreshURL := a.BaseURL + "/oauth/token"
//nolint:gosec // Ignore G107: Potential HTTP request made with variable url
resp, err := http.PostForm(refreshURL, data)
if err != nil {
return
}

err = json.NewDecoder(resp.Body).Decode(&res)
_ = resp.Body.Close()

return
}

func (a API) LogoutURL() string {
return fmt.Sprintf("%s/v2/logout?client_id=%s", a.BaseURL, a.ClientID)
}
188 changes: 188 additions & 0 deletions cli/config/credentials/internal/oauth/api/api_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
package api

import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/docker/cli/cli/config/credentials/internal/oauth/util"
"gotest.tools/v3/assert"
)

func TestGetDeviceCode(t *testing.T) {
t.Run("success", func(t *testing.T) {
var clientID, audience, scope, path string
expectedState := State{
DeviceCode: "aDeviceCode",
UserCode: "aUserCode",
VerificationURI: "aVerificationURI",
ExpiresIn: 60,
}
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.ParseForm()
clientID = r.FormValue("client_id")
audience = r.FormValue("audience")
scope = r.FormValue("scope")
path = r.URL.Path

jsonState, err := json.Marshal(expectedState)
assert.NilError(t, err)

_, _ = w.Write(jsonState)
}))
defer ts.Close()
api := API{
BaseURL: ts.URL,
ClientID: "aClientID",
Scopes: []string{"bork", "meow"},
Client: util.Client{},
}

state, err := api.GetDeviceCode(context.Background(), "anAudience")
assert.NilError(t, err)

assert.DeepEqual(t, expectedState, state)
assert.Equal(t, clientID, "aClientID")
assert.Equal(t, audience, "anAudience")
assert.Equal(t, scope, "bork meow")
assert.Equal(t, path, "/oauth/device/code")
})

t.Run("error w/ description", func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
jsonState, err := json.Marshal(TokenResponse{
ErrorDescription: "invalid audience",
})
assert.NilError(t, err)

w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write(jsonState)
}))
defer ts.Close()
api := API{
BaseURL: ts.URL,
ClientID: "aClientID",
Scopes: []string{"bork", "meow"},
Client: util.Client{},
}

_, err := api.GetDeviceCode(context.Background(), "bad_audience")

assert.ErrorContains(t, err, "invalid audience")
})

t.Run("general error", func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "an error", http.StatusInternalServerError)
}))
defer ts.Close()
api := API{
BaseURL: ts.URL,
ClientID: "aClientID",
Scopes: []string{"bork", "meow"},
Client: util.Client{},
}

_, err := api.GetDeviceCode(context.Background(), "anAudience")

assert.ErrorContains(t, err, "failed to get device code")
})

// todo(laurazard): test
t.Run("canceled context", func(t *testing.T) {})
}

func TestWaitForDeviceToken(t *testing.T) {
t.Run("success", func(t *testing.T) {
expectedToken := TokenResponse{
AccessToken: "a-real-token",
IDToken: "",
RefreshToken: "the-refresh-token",
Scope: "",
ExpiresIn: 3600,
TokenType: "",
}
var respond bool
go func() {
time.Sleep(10 * time.Second)
respond = true
}()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
assert.Equal(t, "/oauth/token", r.URL.Path)
assert.Equal(t, r.FormValue("client_id"), "aClientID")
assert.Equal(t, r.FormValue("grant_type"), "urn:ietf:params:oauth:grant-type:device_code")
assert.Equal(t, r.FormValue("device_code"), "aDeviceCode")

if respond {
jsonState, err := json.Marshal(expectedToken)
assert.NilError(t, err)
w.Write(jsonState)
} else {
pendingError := "authorization_pending"
jsonResponse, err := json.Marshal(TokenResponse{
Error: &pendingError,
})
assert.NilError(t, err)
w.Write(jsonResponse)
}
}))
defer ts.Close()
api := API{
BaseURL: ts.URL,
ClientID: "aClientID",
Scopes: []string{"bork", "meow"},
Client: util.Client{},
}
state := State{
DeviceCode: "aDeviceCode",
UserCode: "aUserCode",
Interval: 1,
ExpiresIn: 30,
}
token, err := api.WaitForDeviceToken(context.Background(), state)
assert.NilError(t, err)

assert.DeepEqual(t, token, expectedToken)
})

t.Run("timeout", func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
assert.Equal(t, "/oauth/token", r.URL.Path)
assert.Equal(t, r.FormValue("client_id"), "aClientID")
assert.Equal(t, r.FormValue("grant_type"), "urn:ietf:params:oauth:grant-type:device_code")
assert.Equal(t, r.FormValue("device_code"), "aDeviceCode")

pendingError := "authorization_pending"
jsonResponse, err := json.Marshal(TokenResponse{
Error: &pendingError,
})
assert.NilError(t, err)
w.Write(jsonResponse)
}))
defer ts.Close()
api := API{
BaseURL: ts.URL,
ClientID: "aClientID",
Scopes: []string{"bork", "meow"},
Client: util.Client{},
}
state := State{
DeviceCode: "aDeviceCode",
UserCode: "aUserCode",
Interval: 1,
ExpiresIn: 1,
}

_, err := api.WaitForDeviceToken(context.Background(), state)

assert.ErrorIs(t, err, ErrTimeout)
})

// todo(laurazard): test
t.Run("canceled context", func(t *testing.T) {})
}
20 changes: 20 additions & 0 deletions cli/config/credentials/internal/oauth/api/state.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package api

import (
"time"
)

// State represents the state of exchange after submitting.
type State struct {
DeviceCode string `json:"device_code"`
UserCode string `json:"user_code"`
VerificationURI string `json:"verification_uri_complete"`
ExpiresIn int `json:"expires_in"`
Interval int `json:"interval"`
}

// IntervalDuration returns the duration that should be waited between each auth
// polling event.
func (s State) IntervalDuration() time.Duration {
return time.Second * time.Duration(s.Interval)
}
Loading
Loading