Skip to content

Commit

Permalink
feat: sign out all sessions (#60)
Browse files Browse the repository at this point in the history
## Description

Adding support for `/sign_out_all_sessions`.

/sign_out_all_sessions endpoint will remove the current session and make
a POST request to IAM, configured via
`OAUTH2_PROXY_BACKEND_LOGOUT_ALL_SESSIONS_URL` env, to invalidate all
the tokens and sessions. This will not invalidate other user sessions.

Once the tokens and sessions are invalidated, after the refresh token
period defined on the `OAUTH2_PROXY_COOKIE_REFRESH` env, OAuth will fail
to refresh the access token and clear that session.

related to:
- philips-internal/pics-foundation-envoy#105
- philips-internal/pics#3006


[AB#1579962](https://tfsemea1.ta.philips.com/tfs/TPC_Region11/0839b845-d626-4499-94ae-563a86a88d0a/_workitems/edit/1579962)

## Motivation and Context

Possibility for signing out on all devices.

## How Has This Been Tested?

Integrated locally with PICS by running binary. Docs
[here](https://github.com/philips-internal/pics/blob/main/src/services/Oauth2Proxy/docs/development.md).

## Checklist:

- [x] Add OAUTH2_PROXY_BACKEND_LOGOUT_ALL_SESSIONS_URL env
- [x] Add  /sign_out_all_sessions endpoint
- [x] Remove other user sessions when tokens are invalid
  • Loading branch information
andersonvcv authored Dec 30, 2024
2 parents a9ac8d5 + 982e27f commit 48f13a7
Show file tree
Hide file tree
Showing 10 changed files with 174 additions and 26 deletions.
1 change: 1 addition & 0 deletions docs/docs/configuration/alpha_config.md
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,7 @@ Provider holds all configuration for a single provider
| `allowedGroups` | _[]string_ | AllowedGroups is a list of restrict logins to members of this group |
| `code_challenge_method` | _string_ | The code challenge method |
| `backendLogoutURL` | _string_ | URL to call to perform backend logout, `{id_token}` would be replaced by the actual `id_token` if available in the session |
| `backendLogoutAllSessionsURL` | _string_ | URL to call to perform backend logout, `{user_id}` would be replaced by the actual `user_id` if available in the session IntrospectClaims |
### ProviderType
#### (`string` alias)
Expand Down
62 changes: 44 additions & 18 deletions oauthproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,16 @@ func (p *OAuthProxy) buildProxySubrouter(s *mux.Router) {

// The userinfo and logout endpoints needs to load sessions before handling the request
s.Path(userInfoPath).Handler(p.sessionChain.ThenFunc(p.UserInfo))
s.Path(signOutPath).Handler(p.sessionChain.ThenFunc(p.SignOut))
s.Path(signOutPath).Handler(p.sessionChain.ThenFunc(
func(w http.ResponseWriter, r *http.Request) {
p.SignOut(w, r, false)
},
))
s.Path(picsSignOutAllDevicesPath).Handler(p.sessionChain.ThenFunc(
func(w http.ResponseWriter, r *http.Request) {
p.SignOut(w, r, true)
},
))
}

// buildPreAuthChain constructs a chain that should process every request before
Expand Down Expand Up @@ -758,7 +767,7 @@ func (p *OAuthProxy) UserInfo(rw http.ResponseWriter, req *http.Request) {
}

// SignOut sends a response to clear the authentication cookie
func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) {
func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request, signOutAllSessions bool) {
redirect, err := p.appDirector.GetRedirect(req)
if err != nil {
logger.Errorf("Error obtaining redirect: %v", err)
Expand All @@ -772,12 +781,12 @@ func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) {
return
}

p.backendLogout(rw, req)
p.backendLogout(rw, req, signOutAllSessions)

http.Redirect(rw, req, redirect, http.StatusFound)
}

func (p *OAuthProxy) backendLogout(rw http.ResponseWriter, req *http.Request) {
func (p *OAuthProxy) backendLogout(rw http.ResponseWriter, req *http.Request, signOutAllSessions bool) {
session, err := p.getAuthenticatedSession(rw, req)
if err != nil {
logger.Errorf("error getting authenticated session during backend logout: %v", err)
Expand All @@ -789,22 +798,39 @@ func (p *OAuthProxy) backendLogout(rw http.ResponseWriter, req *http.Request) {
}

providerData := p.provider.Data()
if providerData.BackendLogoutURL == "" {
return
}
var resp *http.Response
if signOutAllSessions {
if providerData.BackendLogoutAllSessionsURL == "" {
return
}

backendLogoutURL := strings.ReplaceAll(providerData.BackendLogoutURL, "{id_token}", session.IDToken)
// security exception because URL is dynamic ({id_token} replacement) but
// base is not end-user provided but comes from configuration somewhat secure
resp, err := http.Get(backendLogoutURL) // #nosec G107
if err != nil {
logger.Errorf("error while calling backend logout: %v", err)
return
}
resp, err := PicsSignOutAllSessions(providerData.BackendLogoutAllSessionsURL, session.IntrospectClaims, session.AccessToken)
if err != nil {
logger.Errorf("error while calling backend logout all sessions: %v", err)
return
}

defer resp.Body.Close()
if resp.StatusCode != 200 {
logger.Errorf("error while calling backend logout url, returned error code %v", resp.StatusCode)
if resp.StatusCode() != 200 {
logger.Errorf("error while calling backend logout url, returned error code %v", resp.StatusCode())
}
} else {
if providerData.BackendLogoutURL == "" {
return
}

backendLogoutURL := strings.ReplaceAll(providerData.BackendLogoutURL, "{id_token}", session.IDToken)
// security exception because URL is dynamic ({id_token} replacement) but
// base is not end-user provided but comes from configuration somewhat secure
resp, err = http.Get(backendLogoutURL) // #nosec G107
if err != nil {
logger.Errorf("error while calling backend logout: %v", err)
return
}

defer resp.Body.Close()
if resp.StatusCode != 200 {
logger.Errorf("error while calling backend logout url, returned error code %v", resp.StatusCode)
}
}
}

Expand Down
59 changes: 59 additions & 0 deletions pics_oauthproxy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package main

import (
"encoding/base64"
"encoding/json"
"fmt"
"strings"

"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests"
)

const (
picsSignOutAllDevicesPath = "/sign_out_all_sessions"
)

func PicsSignOutAllSessions(backendLogoutAllSessionsURL string, introspectClaims string, accessToken string) (resp requests.Result, err error) {
userID, err := getUserID(introspectClaims)
if err != nil {
return nil, fmt.Errorf("error getting userID from instrospect claims: %v", err)
}

backendLogoutURL := strings.ReplaceAll(backendLogoutAllSessionsURL, "{user_id}", userID)
resp = requests.New(backendLogoutURL).
WithMethod("POST").
SetHeader("Authorization", "Bearer "+accessToken).
SetHeader("API-Version", "1").
SetHeader("Accept", "application/json").
Do()

if resp.Error() != nil {
return nil, fmt.Errorf("error logging out from IAM: %v", err)
}

return resp, err
}

func getUserID(introspectClaims string) (string, error) {
decodedClaims, err := base64.StdEncoding.DecodeString(introspectClaims)
if err != nil {
logger.Errorf("error decoding claims: %v", err)
return "", err
}

var claims map[string]interface{}
err = json.Unmarshal(decodedClaims, &claims)
if err != nil {
logger.Errorf("error unmarshalling claims: %v", err)
return "", err
}

userID, ok := claims["sub"].(string)
if !ok {
logger.Errorf("error extracting 'sub' from claims")
return "", err
}

return userID, nil
}
55 changes: 55 additions & 0 deletions pics_oauthproxy_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package main

import (
"encoding/base64"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
)

func createIntrospectClaims() string {
claims := map[string]interface{}{
"sub": "1234567890",
}
claimsBytes, err := json.Marshal(claims)
if err != nil {
return ""
}

return base64.StdEncoding.EncodeToString(claimsBytes)
}

func Test_PicsSignOutAllSessionsReturnsErrorWhenUserIDIsNotFound(t *testing.T) {
_, err := PicsSignOutAllSessions("http://localhost:8080/test", "", "")

assert.Error(t, err)
}

func Test_getUserID(t *testing.T) {
introspectClaims := createIntrospectClaims()
userID, err := getUserID(introspectClaims)

assert.NoError(t, err)
assert.Equal(t, "1234567890", userID)
}

func Test_PicsSignOutAllSessionsReturns200Ok(t *testing.T) {
introspectClaims := createIntrospectClaims()
accessToken := "validAccessToken"

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "Bearer "+accessToken, r.Header.Get("Authorization"))
assert.Equal(t, "1", r.Header.Get("API-Version"))
assert.Equal(t, "application/json", r.Header.Get("Accept"))
w.WriteHeader(http.StatusOK)
}))
defer server.Close()

resp, err := PicsSignOutAllSessions(server.URL+"/{user_id}", introspectClaims, accessToken)

assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode())
}
5 changes: 5 additions & 0 deletions pkg/apis/options/legacy_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,8 @@ type LegacyProvider struct {
AllowedRoles []string `flag:"allowed-role" cfg:"allowed_roles"`
BackendLogoutURL string `flag:"backend-logout-url" cfg:"backend_logout_url"`

BackendLogoutAllSessionsURL string `flag:"backend-logout-all-sessions-url" cfg:"backend_logout_all_sessions_url"`

AcrValues string `flag:"acr-values" cfg:"acr_values"`
JWTKey string `flag:"jwt-key" cfg:"jwt_key"`
JWTKeyFile string `flag:"jwt-key-file" cfg:"jwt_key_file"`
Expand Down Expand Up @@ -613,6 +615,7 @@ func legacyProviderFlagSet() *pflag.FlagSet {
flagSet.StringSlice("allowed-group", []string{}, "restrict logins to members of this group (may be given multiple times)")
flagSet.StringSlice("allowed-role", []string{}, "(keycloak-oidc) restrict logins to members of these roles (may be given multiple times)")
flagSet.String("backend-logout-url", "", "url to perform a backend logout, {id_token} can be used as placeholder for the id_token")
flagSet.String("backend-logout-all-sessions-url", "", "url to perform a backend logout, {user_id} can be used as placeholder for the user_id")

return flagSet
}
Expand Down Expand Up @@ -693,6 +696,8 @@ func (l *LegacyProvider) convert() (Providers, error) {
AllowedGroups: l.AllowedGroups,
CodeChallengeMethod: l.CodeChallengeMethod,
BackendLogoutURL: l.BackendLogoutURL,

BackendLogoutAllSessionsURL: l.BackendLogoutAllSessionsURL,
}

// This part is out of the switch section for all providers that support OIDC
Expand Down
3 changes: 3 additions & 0 deletions pkg/apis/options/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ type Provider struct {

// URL to call to perform backend logout, `{id_token}` would be replaced by the actual `id_token` if available in the session
BackendLogoutURL string `json:"backendLogoutURL"`

// URL to call to perform backend logout, `{user_id}` would be replaced by the actual `user_id` if available in the session IntrospectClaims
BackendLogoutAllSessionsURL string `json:"backendLogoutAllSessionsURL"`
}

// ProviderType is used to enumerate the different provider type options
Expand Down
3 changes: 2 additions & 1 deletion pkg/middleware/stored_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,8 @@ func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req
if err := s.refreshSession(rw, req, session); err != nil {
// If a preemptive refresh fails, we still keep the session
// if validateSession succeeds.
logger.Errorf("Unable to refresh session: %v", err)
// PICS: We will clean the session if the refresh fails.
return fmt.Errorf("unable to refresh session: %v", err)
}

// Validate all sessions after any Redeem/Refresh operation (fail or success)
Expand Down
9 changes: 2 additions & 7 deletions pkg/middleware/stored_session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -295,17 +295,12 @@ var _ = Describe("Stored Session Suite", func() {
refreshSession: defaultRefreshFunc,
validateSession: defaultValidateFunc,
}),
Entry("when the provider refresh fails but validation succeeds", storedSessionLoaderTableInput{
Entry("when the provider refresh fails", storedSessionLoaderTableInput{
requestHeaders: http.Header{
"Cookie": []string{"_oauth2_proxy=RefreshError"},
},
existingSession: nil,
expectedSession: &sessionsapi.SessionState{
RefreshToken: "RefreshError",
CreatedAt: &createdPast,
ExpiresOn: &createdFuture,
Lock: &sessionsapi.NoOpLock{},
},
expectedSession: nil,
store: defaultSessionStore,
refreshPeriod: 1 * time.Minute,
refreshSession: defaultRefreshFunc,
Expand Down
2 changes: 2 additions & 0 deletions providers/provider_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ type ProviderData struct {
loginURLParameterOverrides map[string]*regexp.Regexp

BackendLogoutURL string

BackendLogoutAllSessionsURL string
}

// Data returns the ProviderData
Expand Down
1 change: 1 addition & 0 deletions providers/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ func newProviderDataFromConfig(providerConfig options.Provider) (*ProviderData,
p.setAllowedGroups(providerConfig.AllowedGroups)

p.BackendLogoutURL = providerConfig.BackendLogoutURL
p.BackendLogoutAllSessionsURL = providerConfig.BackendLogoutAllSessionsURL

return p, nil
}
Expand Down

0 comments on commit 48f13a7

Please sign in to comment.