Skip to content
This repository has been archived by the owner on Apr 22, 2024. It is now read-only.

Commit

Permalink
refactor access token expiration storage
Browse files Browse the repository at this point in the history
  • Loading branch information
nacx committed Feb 14, 2024
1 parent 6772f8e commit ebb5e28
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 48 deletions.
5 changes: 4 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,12 @@ coverage: ## Creates coverage report for all projects
e2e: ## Runt he e2e tests
@$(MAKE) -C e2e e2e

e2e/%:
e2e/%: force-e2e
@$(MAKE) -C e2e $(@)

.PHONY: force-e2e
force-e2e:

##@ Docker targets

.PHONY: docker-pre
Expand Down
10 changes: 8 additions & 2 deletions e2e/redis/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,20 @@ func TestRedisTokenResponse(t *testing.T) {

// Create a session and verify it's added and accessed time
tr = &oidc.TokenResponse{
IDToken: newToken(),
AccessToken: newToken(),
IDToken: newToken(),
AccessToken: newToken(),
AccessTokenExpiresAt: time.Now().Add(30 * time.Minute),
}
require.NoError(t, store.SetTokenResponse(ctx, "s1", tr))

// Verify we can retrieve the token
got, err := store.GetTokenResponse(ctx, "s1")
require.NoError(t, err)
// The testify library doesn't properly compare times, so we need to do it manually
// then set the times in the returned object so that we can compare the rest of the
// fields normally
require.True(t, tr.AccessTokenExpiresAt.Equal(got.AccessTokenExpiresAt))
got.AccessTokenExpiresAt = tr.AccessTokenExpiresAt
require.Equal(t, tr, got)

// Verify that the token TTL has been set
Expand Down
43 changes: 27 additions & 16 deletions internal/oidc/redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,18 @@ var (
)

const (
keyIDToken = "id_token"
keyAccessToken = "access_token"
keyRefreshToken = "refresh_token"
keyState = "state"
keyNonce = "nonce"
keyRequestedURL = "requested_url"
keyTimeAdded = "time_added"
keyIDToken = "id_token"
keyAccessToken = "access_token"
keyAccessTokenExpiry = "access_token_expiry"
keyRefreshToken = "refresh_token"
keyState = "state"
keyNonce = "nonce"
keyRequestedURL = "requested_url"
keyTimeAdded = "time_added"
)

var (
tokenResponseKeys = []string{keyIDToken, keyAccessToken, keyRefreshToken, keyTimeAdded}
tokenResponseKeys = []string{keyIDToken, keyAccessToken, keyRefreshToken, keyAccessTokenExpiry, keyTimeAdded}
// authorizationStateKeys = []string{keyState, keyNonce, keyRequestedURL, keyTimeAdded}
)

Expand Down Expand Up @@ -89,6 +90,14 @@ func (r *redisStore) SetTokenResponse(ctx context.Context, sessionID string, tok
keysToDelete = append(keysToDelete, keyAccessToken)
}

if !tokenResponse.AccessTokenExpiresAt.IsZero() {
if err := r.client.HSet(ctx, sessionID, keyAccessTokenExpiry, tokenResponse.AccessTokenExpiresAt).Err(); err != nil {
return err
}
} else {
keysToDelete = append(keysToDelete, keyAccessTokenExpiry)
}

if tokenResponse.RefreshToken != "" {
if err := r.client.HSet(ctx, sessionID, keyRefreshToken, tokenResponse.RefreshToken).Err(); err != nil {
return err
Expand Down Expand Up @@ -130,7 +139,7 @@ func (r *redisStore) GetTokenResponse(ctx context.Context, sessionID string) (*T
}

tokenResponse := token.TokenResponse()
if _, err := tokenResponse.GetIDToken(); err != nil {
if _, err := tokenResponse.ParseIDToken(); err != nil {
log.Error("failed to parse id token", err, "session_id", sessionID, "token", token)
return nil, nil
}
Expand Down Expand Up @@ -180,16 +189,18 @@ func (r *redisStore) refreshExpiration(ctx context.Context, sessionID string, ti
}

type redisToken struct {
IDToken string `redis:"id_token"`
AccessToken string `redis:"access_token"`
RefreshToken string `redis:"refresh_token"`
TimeAdded time.Time `redis:"time_added"`
IDToken string `redis:"id_token"`
AccessToken string `redis:"access_token"`
AccessTokenExpiresAt time.Time `redis:"access_token_expiry"`
RefreshToken string `redis:"refresh_token"`
TimeAdded time.Time `redis:"time_added"`
}

func (r redisToken) TokenResponse() TokenResponse {
return TokenResponse{
IDToken: r.IDToken,
AccessToken: r.AccessToken,
RefreshToken: r.RefreshToken,
IDToken: r.IDToken,
AccessToken: r.AccessToken,
AccessTokenExpiresAt: r.AccessTokenExpiresAt,
RefreshToken: r.RefreshToken,
}
}
39 changes: 36 additions & 3 deletions internal/oidc/redis_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,21 @@ func TestRedisTokenResponse(t *testing.T) {

// Create a session and verify it's added and accessed time
tr = &TokenResponse{
IDToken: newToken(),
AccessToken: newToken(),
RefreshToken: newToken(),
IDToken: newToken(),
AccessToken: newToken(),
AccessTokenExpiresAt: time.Now().Add(30 * time.Minute),
RefreshToken: newToken(),
}
require.NoError(t, store.SetTokenResponse(ctx, "s1", tr))

// Verify we can retrieve the token
got, err := store.GetTokenResponse(ctx, "s1")
require.NoError(t, err)
// The testify library doesn't properly compare times, so we need to do it manually
// then set the times in the returned object so that we can compare the rest of the
// fields normally
require.True(t, tr.AccessTokenExpiresAt.Equal(got.AccessTokenExpiresAt))
got.AccessTokenExpiresAt = tr.AccessTokenExpiresAt
require.Equal(t, tr, got)

// Verify that the token TTL has been set
Expand Down Expand Up @@ -75,3 +81,30 @@ func TestRedisPingError(t *testing.T) {
_, err := NewRedisStore(&Clock{}, client, 0, 1*time.Minute)
require.EqualError(t, err, "ping error")
}

func TestRefreshExpiration(t *testing.T) {
mr := miniredis.RunT(t)
client := redis.NewClient(&redis.Options{Addr: mr.Addr()})
store, err := NewRedisStore(&Clock{}, client, 0, 0)
require.NoError(t, err)

ctx := context.Background()

t.Run("delete session if no time added", func(t *testing.T) {
require.NoError(t, client.HSet(ctx, "s1", keyAccessToken, "").Err())
err := store.(*redisStore).refreshExpiration(ctx, "s1", time.Time{})
require.ErrorIs(t, err, ErrRedis)
require.Equal(t, redis.Nil, client.Get(ctx, "s1").Err())
})

t.Run("no expiration set if no timeouts", func(t *testing.T) {
require.NoError(t, client.HSet(ctx, "s1", keyTimeAdded, time.Now()).Err())
require.NoError(t, store.(*redisStore).refreshExpiration(ctx, "s1", time.Time{}))

res, err := client.TTL(ctx, "s1").Result()
require.NoError(t, err)
require.Equal(t, time.Duration(-1), res)
})

// TODO(nacx): Expiration is updated
}
13 changes: 7 additions & 6 deletions internal/oidc/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,20 @@
package oidc

import (
"time"

"github.com/lestrrat-go/jwx/jwt"
)

// TokenResponse contains information about the tokens returned by the Identity Provider.
type TokenResponse struct {
IDToken string
AccessToken string
RefreshToken string
IDToken string
AccessToken string
AccessTokenExpiresAt time.Time
RefreshToken string
}

func (t *TokenResponse) GetIDToken() (jwt.Token, error) { return parse(t.IDToken) }
func (t *TokenResponse) GetAccessToken() (jwt.Token, error) { return parse(t.AccessToken) }
func (t *TokenResponse) GetRefreshToken() (jwt.Token, error) { return parse(t.RefreshToken) }
func (t *TokenResponse) ParseIDToken() (jwt.Token, error) { return parse(t.IDToken) }

func parse(token string) (jwt.Token, error) {
return jwt.Parse([]byte(token), jwt.WithValidate(false))
Expand Down
24 changes: 4 additions & 20 deletions internal/oidc/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,36 +23,20 @@ import (
"github.com/stretchr/testify/require"
)

func TestTokenResponse(t *testing.T) {
func TestParseIDToken(t *testing.T) {
t.Run("valid", func(t *testing.T) {
tr := &TokenResponse{
IDToken: newToken(),
AccessToken: newToken(),
RefreshToken: newToken(),
IDToken: newToken(),
}

it, err := tr.GetIDToken()
it, err := tr.ParseIDToken()
require.NoError(t, err)
require.Equal(t, "authservice", it.Issuer())

at, err := tr.GetAccessToken()
require.NoError(t, err)
require.Equal(t, "authservice", at.Issuer())

rt, err := tr.GetRefreshToken()
require.NoError(t, err)
require.Equal(t, "authservice", rt.Issuer())
})

t.Run("invalid", func(t *testing.T) {
tr := &TokenResponse{}
_, err := tr.GetIDToken()
require.Error(t, err)

_, err = tr.GetAccessToken()
require.Error(t, err)

_, err = tr.GetRefreshToken()
_, err := tr.ParseIDToken()
require.Error(t, err)
})
}
Expand Down

0 comments on commit ebb5e28

Please sign in to comment.