-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Laura Brehm <[email protected]>
- Loading branch information
Showing
55 changed files
with
11,737 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
package api | ||
|
||
import ( | ||
"context" | ||
"encoding/json" | ||
"errors" | ||
"fmt" | ||
"net/http" | ||
"net/url" | ||
"strings" | ||
"time" | ||
|
||
"github.com/docker/cli/cli/config/credentials/internal/oauth/util" | ||
) | ||
|
||
type OAuthAPI interface { | ||
GetDeviceCode(ctx context.Context, audience string) (State, error) | ||
WaitForDeviceToken(ctx context.Context, state State) (TokenResponse, error) | ||
Refresh(ctx context.Context, token string) (TokenResponse, error) | ||
LogoutURL() string | ||
} | ||
|
||
// API represents API interactions with Auth0. | ||
type API struct { | ||
// BaseURL is the base used for each request to Auth0. | ||
BaseURL string | ||
// ClientID is the client ID for the application to auth with the tenant. | ||
ClientID string | ||
// Scopes are the scopes that are requested during the device auth flow. | ||
Scopes []string | ||
// Client is the client that is used for calls. | ||
Client util.Client | ||
} | ||
|
||
// TokenResponse represents the response of the /oauth/token route. | ||
type TokenResponse struct { | ||
AccessToken string `json:"access_token"` | ||
IDToken string `json:"id_token"` | ||
RefreshToken string `json:"refresh_token"` | ||
Scope string `json:"scope"` | ||
ExpiresIn int `json:"expires_in"` | ||
TokenType string `json:"token_type"` | ||
Error *string `json:"error,omitempty"` | ||
ErrorDescription string `json:"error_description,omitempty"` | ||
} | ||
|
||
var ErrTimeout = errors.New("timed out waiting for device token") | ||
|
||
// GetDeviceCode returns device code authorization information from Auth0. | ||
func (a API) GetDeviceCode(ctx context.Context, audience string) (state State, err error) { | ||
data := url.Values{ | ||
"client_id": {a.ClientID}, | ||
"audience": {audience}, | ||
"scope": {strings.Join(a.Scopes, " ")}, | ||
} | ||
|
||
deviceCodeURL := a.BaseURL + "/oauth/device/code" | ||
resp, err := a.Client.PostForm(deviceCodeURL, strings.NewReader(data.Encode())) | ||
if err != nil { | ||
return | ||
} | ||
defer func() { | ||
_ = resp.Body.Close() | ||
}() | ||
|
||
if resp.StatusCode != http.StatusOK { | ||
var body map[string]any | ||
err = json.NewDecoder(resp.Body).Decode(&body) | ||
if errorDescription, ok := body["error_description"].(string); ok { | ||
return state, errors.New(errorDescription) | ||
} | ||
return state, fmt.Errorf("failed to get device code: %w", err) | ||
} | ||
|
||
err = json.NewDecoder(resp.Body).Decode(&state) | ||
|
||
return | ||
} | ||
|
||
// WaitForDeviceToken polls to get tokens based on the device code set up. This | ||
// only works in a device auth flow. | ||
func (a API) WaitForDeviceToken(ctx context.Context, state State) (TokenResponse, error) { | ||
ticker := time.NewTicker(state.IntervalDuration()) | ||
timeout := time.After(time.Duration(state.ExpiresIn) * time.Second) | ||
|
||
for { | ||
select { | ||
case <-ticker.C: | ||
res, err := a.getDeviceToken(state) | ||
if err != nil { | ||
return res, err | ||
} | ||
|
||
if res.Error != nil { | ||
if *res.Error == "authorization_pending" { | ||
continue | ||
} | ||
|
||
return res, errors.New(res.ErrorDescription) | ||
} | ||
|
||
return res, nil | ||
case <-timeout: | ||
ticker.Stop() | ||
return TokenResponse{}, ErrTimeout | ||
} | ||
} | ||
} | ||
|
||
// getToken calls the token endpoint of Auth0 and returns the response. | ||
func (a API) getDeviceToken(state State) (res TokenResponse, err error) { | ||
data := url.Values{ | ||
"client_id": {a.ClientID}, | ||
"grant_type": {"urn:ietf:params:oauth:grant-type:device_code"}, | ||
"device_code": {state.DeviceCode}, | ||
} | ||
oauthTokenURL := a.BaseURL + "/oauth/token" | ||
|
||
resp, err := a.Client.PostForm(oauthTokenURL, strings.NewReader(data.Encode())) | ||
if err != nil { | ||
return res, fmt.Errorf("failed to get code: %w", err) | ||
} | ||
|
||
err = json.NewDecoder(resp.Body).Decode(&res) | ||
_ = resp.Body.Close() | ||
|
||
return | ||
} | ||
|
||
// Refresh returns new tokens based on the refresh token. | ||
func (a API) Refresh(ctx context.Context, token string) (res TokenResponse, err error) { | ||
data := url.Values{ | ||
"grant_type": {"refresh_token"}, | ||
"client_id": {a.ClientID}, | ||
"refresh_token": {token}, | ||
} | ||
|
||
refreshURL := a.BaseURL + "/oauth/token" | ||
//nolint:gosec // Ignore G107: Potential HTTP request made with variable url | ||
resp, err := http.PostForm(refreshURL, data) | ||
if err != nil { | ||
return | ||
} | ||
|
||
err = json.NewDecoder(resp.Body).Decode(&res) | ||
_ = resp.Body.Close() | ||
|
||
return | ||
} | ||
|
||
func (a API) LogoutURL() string { | ||
return fmt.Sprintf("%s/v2/logout?client_id=%s", a.BaseURL, a.ClientID) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,188 @@ | ||
package api | ||
|
||
import ( | ||
"context" | ||
"encoding/json" | ||
"net/http" | ||
"net/http/httptest" | ||
"testing" | ||
"time" | ||
|
||
"github.com/docker/cli/cli/config/credentials/internal/oauth/util" | ||
"gotest.tools/v3/assert" | ||
) | ||
|
||
func TestGetDeviceCode(t *testing.T) { | ||
t.Run("success", func(t *testing.T) { | ||
var clientID, audience, scope, path string | ||
expectedState := State{ | ||
DeviceCode: "aDeviceCode", | ||
UserCode: "aUserCode", | ||
VerificationURI: "aVerificationURI", | ||
ExpiresIn: 60, | ||
} | ||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
r.ParseForm() | ||
clientID = r.FormValue("client_id") | ||
audience = r.FormValue("audience") | ||
scope = r.FormValue("scope") | ||
path = r.URL.Path | ||
|
||
jsonState, err := json.Marshal(expectedState) | ||
assert.NilError(t, err) | ||
|
||
_, _ = w.Write(jsonState) | ||
})) | ||
defer ts.Close() | ||
api := API{ | ||
BaseURL: ts.URL, | ||
ClientID: "aClientID", | ||
Scopes: []string{"bork", "meow"}, | ||
Client: util.Client{}, | ||
} | ||
|
||
state, err := api.GetDeviceCode(context.Background(), "anAudience") | ||
assert.NilError(t, err) | ||
|
||
assert.DeepEqual(t, expectedState, state) | ||
assert.Equal(t, clientID, "aClientID") | ||
assert.Equal(t, audience, "anAudience") | ||
assert.Equal(t, scope, "bork meow") | ||
assert.Equal(t, path, "/oauth/device/code") | ||
}) | ||
|
||
t.Run("error w/ description", func(t *testing.T) { | ||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
jsonState, err := json.Marshal(TokenResponse{ | ||
ErrorDescription: "invalid audience", | ||
}) | ||
assert.NilError(t, err) | ||
|
||
w.WriteHeader(http.StatusBadRequest) | ||
_, _ = w.Write(jsonState) | ||
})) | ||
defer ts.Close() | ||
api := API{ | ||
BaseURL: ts.URL, | ||
ClientID: "aClientID", | ||
Scopes: []string{"bork", "meow"}, | ||
Client: util.Client{}, | ||
} | ||
|
||
_, err := api.GetDeviceCode(context.Background(), "bad_audience") | ||
|
||
assert.ErrorContains(t, err, "invalid audience") | ||
}) | ||
|
||
t.Run("general error", func(t *testing.T) { | ||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
http.Error(w, "an error", http.StatusInternalServerError) | ||
})) | ||
defer ts.Close() | ||
api := API{ | ||
BaseURL: ts.URL, | ||
ClientID: "aClientID", | ||
Scopes: []string{"bork", "meow"}, | ||
Client: util.Client{}, | ||
} | ||
|
||
_, err := api.GetDeviceCode(context.Background(), "anAudience") | ||
|
||
assert.ErrorContains(t, err, "failed to get device code") | ||
}) | ||
|
||
// todo(laurazard): test | ||
t.Run("canceled context", func(t *testing.T) {}) | ||
} | ||
|
||
func TestWaitForDeviceToken(t *testing.T) { | ||
t.Run("success", func(t *testing.T) { | ||
expectedToken := TokenResponse{ | ||
AccessToken: "a-real-token", | ||
IDToken: "", | ||
RefreshToken: "the-refresh-token", | ||
Scope: "", | ||
ExpiresIn: 3600, | ||
TokenType: "", | ||
} | ||
var respond bool | ||
go func() { | ||
time.Sleep(10 * time.Second) | ||
respond = true | ||
}() | ||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
assert.Equal(t, "POST", r.Method) | ||
assert.Equal(t, "/oauth/token", r.URL.Path) | ||
assert.Equal(t, r.FormValue("client_id"), "aClientID") | ||
assert.Equal(t, r.FormValue("grant_type"), "urn:ietf:params:oauth:grant-type:device_code") | ||
assert.Equal(t, r.FormValue("device_code"), "aDeviceCode") | ||
|
||
if respond { | ||
jsonState, err := json.Marshal(expectedToken) | ||
assert.NilError(t, err) | ||
w.Write(jsonState) | ||
} else { | ||
pendingError := "authorization_pending" | ||
jsonResponse, err := json.Marshal(TokenResponse{ | ||
Error: &pendingError, | ||
}) | ||
assert.NilError(t, err) | ||
w.Write(jsonResponse) | ||
} | ||
})) | ||
defer ts.Close() | ||
api := API{ | ||
BaseURL: ts.URL, | ||
ClientID: "aClientID", | ||
Scopes: []string{"bork", "meow"}, | ||
Client: util.Client{}, | ||
} | ||
state := State{ | ||
DeviceCode: "aDeviceCode", | ||
UserCode: "aUserCode", | ||
Interval: 1, | ||
ExpiresIn: 30, | ||
} | ||
token, err := api.WaitForDeviceToken(context.Background(), state) | ||
assert.NilError(t, err) | ||
|
||
assert.DeepEqual(t, token, expectedToken) | ||
}) | ||
|
||
t.Run("timeout", func(t *testing.T) { | ||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
assert.Equal(t, "POST", r.Method) | ||
assert.Equal(t, "/oauth/token", r.URL.Path) | ||
assert.Equal(t, r.FormValue("client_id"), "aClientID") | ||
assert.Equal(t, r.FormValue("grant_type"), "urn:ietf:params:oauth:grant-type:device_code") | ||
assert.Equal(t, r.FormValue("device_code"), "aDeviceCode") | ||
|
||
pendingError := "authorization_pending" | ||
jsonResponse, err := json.Marshal(TokenResponse{ | ||
Error: &pendingError, | ||
}) | ||
assert.NilError(t, err) | ||
w.Write(jsonResponse) | ||
})) | ||
defer ts.Close() | ||
api := API{ | ||
BaseURL: ts.URL, | ||
ClientID: "aClientID", | ||
Scopes: []string{"bork", "meow"}, | ||
Client: util.Client{}, | ||
} | ||
state := State{ | ||
DeviceCode: "aDeviceCode", | ||
UserCode: "aUserCode", | ||
Interval: 1, | ||
ExpiresIn: 1, | ||
} | ||
|
||
_, err := api.WaitForDeviceToken(context.Background(), state) | ||
|
||
assert.ErrorIs(t, err, ErrTimeout) | ||
}) | ||
|
||
// todo(laurazard): test | ||
t.Run("canceled context", func(t *testing.T) {}) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
package api | ||
|
||
import ( | ||
"time" | ||
) | ||
|
||
// State represents the state of exchange after submitting. | ||
type State struct { | ||
DeviceCode string `json:"device_code"` | ||
UserCode string `json:"user_code"` | ||
VerificationURI string `json:"verification_uri_complete"` | ||
ExpiresIn int `json:"expires_in"` | ||
Interval int `json:"interval"` | ||
} | ||
|
||
// IntervalDuration returns the duration that should be waited between each auth | ||
// polling event. | ||
func (s State) IntervalDuration() time.Duration { | ||
return time.Second * time.Duration(s.Interval) | ||
} |
Oops, something went wrong.