Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
Signed-off-by: Laura Brehm <[email protected]>
  • Loading branch information
laurazard committed Jul 8, 2024
1 parent 6abed4e commit 7d1a47d
Show file tree
Hide file tree
Showing 55 changed files with 11,737 additions and 6 deletions.
2 changes: 1 addition & 1 deletion cli/command/registry/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ func runLogin(ctx context.Context, dockerCli command.Cli, opts loginOptions) err
}

creds := dockerCli.ConfigFile().GetCredentialsStore(serverAddress)

// todo(laurazard): this will no longer trigger even when the store is a file store
store, isDefault := creds.(isFileStore)
// Display a warning if we're storing the users password (not a token)
if isDefault && authConfig.Password != "" {
Expand Down
4 changes: 3 additions & 1 deletion cli/command/registry/login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,9 @@ func TestLoginTermination(t *testing.T) {

runErr := make(chan error)
go func() {
runErr <- runLogin(ctx, cli, loginOptions{})
runErr <- runLogin(ctx, cli, loginOptions{
user: "test-user",
})
}()

// Let the prompt get canceled by the context
Expand Down
7 changes: 5 additions & 2 deletions cli/config/configfile/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,10 +254,13 @@ func decodeAuth(authStr string) (string, string, error) {
// GetCredentialsStore returns a new credentials store from the settings in the
// configuration file
func (configFile *ConfigFile) GetCredentialsStore(registryHostname string) credentials.Store {
var credsStore credentials.Store
if helper := getConfiguredCredentialStore(configFile, registryHostname); helper != "" {
return newNativeStore(configFile, helper)
credsStore = newNativeStore(configFile, helper)
} else {
credsStore = credentials.NewFileStore(configFile)
}
return credentials.NewFileStore(configFile)
return credentials.NewOAuthStore(credsStore)
}

// var for unit testing.
Expand Down
153 changes: 153 additions & 0 deletions cli/config/credentials/internal/oauth/api/api.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
package api

import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"time"

"github.com/docker/cli/cli/config/credentials/internal/oauth/util"
)

type OAuthAPI interface {
GetDeviceCode(ctx context.Context, audience string) (State, error)
WaitForDeviceToken(ctx context.Context, state State) (TokenResponse, error)
Refresh(ctx context.Context, token string) (TokenResponse, error)
LogoutURL() string
}

// API represents API interactions with Auth0.
type API struct {
// BaseURL is the base used for each request to Auth0.
BaseURL string
// ClientID is the client ID for the application to auth with the tenant.
ClientID string
// Scopes are the scopes that are requested during the device auth flow.
Scopes []string
// Client is the client that is used for calls.
Client util.Client
}

// TokenResponse represents the response of the /oauth/token route.
type TokenResponse struct {
AccessToken string `json:"access_token"`
IDToken string `json:"id_token"`
RefreshToken string `json:"refresh_token"`
Scope string `json:"scope"`
ExpiresIn int `json:"expires_in"`
TokenType string `json:"token_type"`
Error *string `json:"error,omitempty"`
ErrorDescription string `json:"error_description,omitempty"`
}

var ErrTimeout = errors.New("timed out waiting for device token")

// GetDeviceCode returns device code authorization information from Auth0.
func (a API) GetDeviceCode(ctx context.Context, audience string) (state State, err error) {
data := url.Values{
"client_id": {a.ClientID},
"audience": {audience},
"scope": {strings.Join(a.Scopes, " ")},
}

deviceCodeURL := a.BaseURL + "/oauth/device/code"
resp, err := a.Client.PostForm(deviceCodeURL, strings.NewReader(data.Encode()))
if err != nil {
return
}
defer func() {
_ = resp.Body.Close()
}()

if resp.StatusCode != http.StatusOK {
var body map[string]any
err = json.NewDecoder(resp.Body).Decode(&body)
if errorDescription, ok := body["error_description"].(string); ok {
return state, errors.New(errorDescription)
}
return state, fmt.Errorf("failed to get device code: %w", err)
}

err = json.NewDecoder(resp.Body).Decode(&state)

return
}

// WaitForDeviceToken polls to get tokens based on the device code set up. This
// only works in a device auth flow.
func (a API) WaitForDeviceToken(ctx context.Context, state State) (TokenResponse, error) {
ticker := time.NewTicker(state.IntervalDuration())
timeout := time.After(time.Duration(state.ExpiresIn) * time.Second)

for {
select {
case <-ticker.C:
res, err := a.getDeviceToken(state)
if err != nil {
return res, err
}

if res.Error != nil {
if *res.Error == "authorization_pending" {
continue
}

return res, errors.New(res.ErrorDescription)
}

return res, nil
case <-timeout:
ticker.Stop()
return TokenResponse{}, ErrTimeout
}
}
}

// getToken calls the token endpoint of Auth0 and returns the response.
func (a API) getDeviceToken(state State) (res TokenResponse, err error) {
data := url.Values{
"client_id": {a.ClientID},
"grant_type": {"urn:ietf:params:oauth:grant-type:device_code"},
"device_code": {state.DeviceCode},
}
oauthTokenURL := a.BaseURL + "/oauth/token"

resp, err := a.Client.PostForm(oauthTokenURL, strings.NewReader(data.Encode()))
if err != nil {
return res, fmt.Errorf("failed to get code: %w", err)
}

err = json.NewDecoder(resp.Body).Decode(&res)
_ = resp.Body.Close()

return
}

// Refresh returns new tokens based on the refresh token.
func (a API) Refresh(ctx context.Context, token string) (res TokenResponse, err error) {
data := url.Values{
"grant_type": {"refresh_token"},
"client_id": {a.ClientID},
"refresh_token": {token},
}

refreshURL := a.BaseURL + "/oauth/token"
//nolint:gosec // Ignore G107: Potential HTTP request made with variable url
resp, err := http.PostForm(refreshURL, data)
if err != nil {
return
}

err = json.NewDecoder(resp.Body).Decode(&res)
_ = resp.Body.Close()

return
}

func (a API) LogoutURL() string {
return fmt.Sprintf("%s/v2/logout?client_id=%s", a.BaseURL, a.ClientID)
}
188 changes: 188 additions & 0 deletions cli/config/credentials/internal/oauth/api/api_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
package api

import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/docker/cli/cli/config/credentials/internal/oauth/util"
"gotest.tools/v3/assert"
)

func TestGetDeviceCode(t *testing.T) {
t.Run("success", func(t *testing.T) {
var clientID, audience, scope, path string
expectedState := State{
DeviceCode: "aDeviceCode",
UserCode: "aUserCode",
VerificationURI: "aVerificationURI",
ExpiresIn: 60,
}
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.ParseForm()
clientID = r.FormValue("client_id")
audience = r.FormValue("audience")
scope = r.FormValue("scope")
path = r.URL.Path

jsonState, err := json.Marshal(expectedState)
assert.NilError(t, err)

_, _ = w.Write(jsonState)
}))
defer ts.Close()
api := API{
BaseURL: ts.URL,
ClientID: "aClientID",
Scopes: []string{"bork", "meow"},
Client: util.Client{},
}

state, err := api.GetDeviceCode(context.Background(), "anAudience")
assert.NilError(t, err)

assert.DeepEqual(t, expectedState, state)
assert.Equal(t, clientID, "aClientID")
assert.Equal(t, audience, "anAudience")
assert.Equal(t, scope, "bork meow")
assert.Equal(t, path, "/oauth/device/code")
})

t.Run("error w/ description", func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
jsonState, err := json.Marshal(TokenResponse{
ErrorDescription: "invalid audience",
})
assert.NilError(t, err)

w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write(jsonState)
}))
defer ts.Close()
api := API{
BaseURL: ts.URL,
ClientID: "aClientID",
Scopes: []string{"bork", "meow"},
Client: util.Client{},
}

_, err := api.GetDeviceCode(context.Background(), "bad_audience")

assert.ErrorContains(t, err, "invalid audience")
})

t.Run("general error", func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "an error", http.StatusInternalServerError)
}))
defer ts.Close()
api := API{
BaseURL: ts.URL,
ClientID: "aClientID",
Scopes: []string{"bork", "meow"},
Client: util.Client{},
}

_, err := api.GetDeviceCode(context.Background(), "anAudience")

assert.ErrorContains(t, err, "failed to get device code")
})

// todo(laurazard): test
t.Run("canceled context", func(t *testing.T) {})
}

func TestWaitForDeviceToken(t *testing.T) {
t.Run("success", func(t *testing.T) {
expectedToken := TokenResponse{
AccessToken: "a-real-token",
IDToken: "",
RefreshToken: "the-refresh-token",
Scope: "",
ExpiresIn: 3600,
TokenType: "",
}
var respond bool
go func() {
time.Sleep(10 * time.Second)
respond = true
}()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
assert.Equal(t, "/oauth/token", r.URL.Path)
assert.Equal(t, r.FormValue("client_id"), "aClientID")
assert.Equal(t, r.FormValue("grant_type"), "urn:ietf:params:oauth:grant-type:device_code")
assert.Equal(t, r.FormValue("device_code"), "aDeviceCode")

if respond {
jsonState, err := json.Marshal(expectedToken)
assert.NilError(t, err)
w.Write(jsonState)
} else {
pendingError := "authorization_pending"
jsonResponse, err := json.Marshal(TokenResponse{
Error: &pendingError,
})
assert.NilError(t, err)
w.Write(jsonResponse)
}
}))
defer ts.Close()
api := API{
BaseURL: ts.URL,
ClientID: "aClientID",
Scopes: []string{"bork", "meow"},
Client: util.Client{},
}
state := State{
DeviceCode: "aDeviceCode",
UserCode: "aUserCode",
Interval: 1,
ExpiresIn: 30,
}
token, err := api.WaitForDeviceToken(context.Background(), state)
assert.NilError(t, err)

assert.DeepEqual(t, token, expectedToken)
})

t.Run("timeout", func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
assert.Equal(t, "/oauth/token", r.URL.Path)
assert.Equal(t, r.FormValue("client_id"), "aClientID")
assert.Equal(t, r.FormValue("grant_type"), "urn:ietf:params:oauth:grant-type:device_code")
assert.Equal(t, r.FormValue("device_code"), "aDeviceCode")

pendingError := "authorization_pending"
jsonResponse, err := json.Marshal(TokenResponse{
Error: &pendingError,
})
assert.NilError(t, err)
w.Write(jsonResponse)
}))
defer ts.Close()
api := API{
BaseURL: ts.URL,
ClientID: "aClientID",
Scopes: []string{"bork", "meow"},
Client: util.Client{},
}
state := State{
DeviceCode: "aDeviceCode",
UserCode: "aUserCode",
Interval: 1,
ExpiresIn: 1,
}

_, err := api.WaitForDeviceToken(context.Background(), state)

assert.ErrorIs(t, err, ErrTimeout)
})

// todo(laurazard): test
t.Run("canceled context", func(t *testing.T) {})
}
20 changes: 20 additions & 0 deletions cli/config/credentials/internal/oauth/api/state.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package api

import (
"time"
)

// State represents the state of exchange after submitting.
type State struct {
DeviceCode string `json:"device_code"`
UserCode string `json:"user_code"`
VerificationURI string `json:"verification_uri_complete"`
ExpiresIn int `json:"expires_in"`
Interval int `json:"interval"`
}

// IntervalDuration returns the duration that should be waited between each auth
// polling event.
func (s State) IntervalDuration() time.Duration {
return time.Second * time.Duration(s.Interval)
}
Loading

0 comments on commit 7d1a47d

Please sign in to comment.