From ebb5e2886639f1136e5037c3b988f2721e8b323a Mon Sep 17 00:00:00 2001 From: Ignasi Barrera Date: Wed, 14 Feb 2024 17:37:40 +0100 Subject: [PATCH] refactor access token expiration storage --- Makefile | 5 ++++- e2e/redis/store_test.go | 10 +++++++-- internal/oidc/redis.go | 43 +++++++++++++++++++++++-------------- internal/oidc/redis_test.go | 39 ++++++++++++++++++++++++++++++--- internal/oidc/token.go | 13 +++++------ internal/oidc/token_test.go | 24 ++++----------------- 6 files changed, 86 insertions(+), 48 deletions(-) diff --git a/Makefile b/Makefile index 0c9d7ab..55865c5 100644 --- a/Makefile +++ b/Makefile @@ -118,9 +118,12 @@ coverage: ## Creates coverage report for all projects e2e: ## Runt he e2e tests @$(MAKE) -C e2e e2e -e2e/%: +e2e/%: force-e2e @$(MAKE) -C e2e $(@) +.PHONY: force-e2e +force-e2e: + ##@ Docker targets .PHONY: docker-pre diff --git a/e2e/redis/store_test.go b/e2e/redis/store_test.go index e0005ab..b326ea3 100644 --- a/e2e/redis/store_test.go +++ b/e2e/redis/store_test.go @@ -45,14 +45,20 @@ func TestRedisTokenResponse(t *testing.T) { // Create a session and verify it's added and accessed time tr = &oidc.TokenResponse{ - IDToken: newToken(), - AccessToken: newToken(), + IDToken: newToken(), + AccessToken: newToken(), + AccessTokenExpiresAt: time.Now().Add(30 * time.Minute), } require.NoError(t, store.SetTokenResponse(ctx, "s1", tr)) // Verify we can retrieve the token got, err := store.GetTokenResponse(ctx, "s1") require.NoError(t, err) + // The testify library doesn't properly compare times, so we need to do it manually + // then set the times in the returned object so that we can compare the rest of the + // fields normally + require.True(t, tr.AccessTokenExpiresAt.Equal(got.AccessTokenExpiresAt)) + got.AccessTokenExpiresAt = tr.AccessTokenExpiresAt require.Equal(t, tr, got) // Verify that the token TTL has been set diff --git a/internal/oidc/redis.go b/internal/oidc/redis.go index 5e832cc..68bec15 100644 --- a/internal/oidc/redis.go +++ b/internal/oidc/redis.go @@ -33,17 +33,18 @@ var ( ) const ( - keyIDToken = "id_token" - keyAccessToken = "access_token" - keyRefreshToken = "refresh_token" - keyState = "state" - keyNonce = "nonce" - keyRequestedURL = "requested_url" - keyTimeAdded = "time_added" + keyIDToken = "id_token" + keyAccessToken = "access_token" + keyAccessTokenExpiry = "access_token_expiry" + keyRefreshToken = "refresh_token" + keyState = "state" + keyNonce = "nonce" + keyRequestedURL = "requested_url" + keyTimeAdded = "time_added" ) var ( - tokenResponseKeys = []string{keyIDToken, keyAccessToken, keyRefreshToken, keyTimeAdded} + tokenResponseKeys = []string{keyIDToken, keyAccessToken, keyRefreshToken, keyAccessTokenExpiry, keyTimeAdded} // authorizationStateKeys = []string{keyState, keyNonce, keyRequestedURL, keyTimeAdded} ) @@ -89,6 +90,14 @@ func (r *redisStore) SetTokenResponse(ctx context.Context, sessionID string, tok keysToDelete = append(keysToDelete, keyAccessToken) } + if !tokenResponse.AccessTokenExpiresAt.IsZero() { + if err := r.client.HSet(ctx, sessionID, keyAccessTokenExpiry, tokenResponse.AccessTokenExpiresAt).Err(); err != nil { + return err + } + } else { + keysToDelete = append(keysToDelete, keyAccessTokenExpiry) + } + if tokenResponse.RefreshToken != "" { if err := r.client.HSet(ctx, sessionID, keyRefreshToken, tokenResponse.RefreshToken).Err(); err != nil { return err @@ -130,7 +139,7 @@ func (r *redisStore) GetTokenResponse(ctx context.Context, sessionID string) (*T } tokenResponse := token.TokenResponse() - if _, err := tokenResponse.GetIDToken(); err != nil { + if _, err := tokenResponse.ParseIDToken(); err != nil { log.Error("failed to parse id token", err, "session_id", sessionID, "token", token) return nil, nil } @@ -180,16 +189,18 @@ func (r *redisStore) refreshExpiration(ctx context.Context, sessionID string, ti } type redisToken struct { - IDToken string `redis:"id_token"` - AccessToken string `redis:"access_token"` - RefreshToken string `redis:"refresh_token"` - TimeAdded time.Time `redis:"time_added"` + IDToken string `redis:"id_token"` + AccessToken string `redis:"access_token"` + AccessTokenExpiresAt time.Time `redis:"access_token_expiry"` + RefreshToken string `redis:"refresh_token"` + TimeAdded time.Time `redis:"time_added"` } func (r redisToken) TokenResponse() TokenResponse { return TokenResponse{ - IDToken: r.IDToken, - AccessToken: r.AccessToken, - RefreshToken: r.RefreshToken, + IDToken: r.IDToken, + AccessToken: r.AccessToken, + AccessTokenExpiresAt: r.AccessTokenExpiresAt, + RefreshToken: r.RefreshToken, } } diff --git a/internal/oidc/redis_test.go b/internal/oidc/redis_test.go index f0a8664..fa99e6f 100644 --- a/internal/oidc/redis_test.go +++ b/internal/oidc/redis_test.go @@ -38,15 +38,21 @@ func TestRedisTokenResponse(t *testing.T) { // Create a session and verify it's added and accessed time tr = &TokenResponse{ - IDToken: newToken(), - AccessToken: newToken(), - RefreshToken: newToken(), + IDToken: newToken(), + AccessToken: newToken(), + AccessTokenExpiresAt: time.Now().Add(30 * time.Minute), + RefreshToken: newToken(), } require.NoError(t, store.SetTokenResponse(ctx, "s1", tr)) // Verify we can retrieve the token got, err := store.GetTokenResponse(ctx, "s1") require.NoError(t, err) + // The testify library doesn't properly compare times, so we need to do it manually + // then set the times in the returned object so that we can compare the rest of the + // fields normally + require.True(t, tr.AccessTokenExpiresAt.Equal(got.AccessTokenExpiresAt)) + got.AccessTokenExpiresAt = tr.AccessTokenExpiresAt require.Equal(t, tr, got) // Verify that the token TTL has been set @@ -75,3 +81,30 @@ func TestRedisPingError(t *testing.T) { _, err := NewRedisStore(&Clock{}, client, 0, 1*time.Minute) require.EqualError(t, err, "ping error") } + +func TestRefreshExpiration(t *testing.T) { + mr := miniredis.RunT(t) + client := redis.NewClient(&redis.Options{Addr: mr.Addr()}) + store, err := NewRedisStore(&Clock{}, client, 0, 0) + require.NoError(t, err) + + ctx := context.Background() + + t.Run("delete session if no time added", func(t *testing.T) { + require.NoError(t, client.HSet(ctx, "s1", keyAccessToken, "").Err()) + err := store.(*redisStore).refreshExpiration(ctx, "s1", time.Time{}) + require.ErrorIs(t, err, ErrRedis) + require.Equal(t, redis.Nil, client.Get(ctx, "s1").Err()) + }) + + t.Run("no expiration set if no timeouts", func(t *testing.T) { + require.NoError(t, client.HSet(ctx, "s1", keyTimeAdded, time.Now()).Err()) + require.NoError(t, store.(*redisStore).refreshExpiration(ctx, "s1", time.Time{})) + + res, err := client.TTL(ctx, "s1").Result() + require.NoError(t, err) + require.Equal(t, time.Duration(-1), res) + }) + + // TODO(nacx): Expiration is updated +} diff --git a/internal/oidc/token.go b/internal/oidc/token.go index fd1b8da..71a7871 100644 --- a/internal/oidc/token.go +++ b/internal/oidc/token.go @@ -15,19 +15,20 @@ package oidc import ( + "time" + "github.com/lestrrat-go/jwx/jwt" ) // TokenResponse contains information about the tokens returned by the Identity Provider. type TokenResponse struct { - IDToken string - AccessToken string - RefreshToken string + IDToken string + AccessToken string + AccessTokenExpiresAt time.Time + RefreshToken string } -func (t *TokenResponse) GetIDToken() (jwt.Token, error) { return parse(t.IDToken) } -func (t *TokenResponse) GetAccessToken() (jwt.Token, error) { return parse(t.AccessToken) } -func (t *TokenResponse) GetRefreshToken() (jwt.Token, error) { return parse(t.RefreshToken) } +func (t *TokenResponse) ParseIDToken() (jwt.Token, error) { return parse(t.IDToken) } func parse(token string) (jwt.Token, error) { return jwt.Parse([]byte(token), jwt.WithValidate(false)) diff --git a/internal/oidc/token_test.go b/internal/oidc/token_test.go index 0fb618c..ad4d304 100644 --- a/internal/oidc/token_test.go +++ b/internal/oidc/token_test.go @@ -23,36 +23,20 @@ import ( "github.com/stretchr/testify/require" ) -func TestTokenResponse(t *testing.T) { +func TestParseIDToken(t *testing.T) { t.Run("valid", func(t *testing.T) { tr := &TokenResponse{ - IDToken: newToken(), - AccessToken: newToken(), - RefreshToken: newToken(), + IDToken: newToken(), } - it, err := tr.GetIDToken() + it, err := tr.ParseIDToken() require.NoError(t, err) require.Equal(t, "authservice", it.Issuer()) - - at, err := tr.GetAccessToken() - require.NoError(t, err) - require.Equal(t, "authservice", at.Issuer()) - - rt, err := tr.GetRefreshToken() - require.NoError(t, err) - require.Equal(t, "authservice", rt.Issuer()) }) t.Run("invalid", func(t *testing.T) { tr := &TokenResponse{} - _, err := tr.GetIDToken() - require.Error(t, err) - - _, err = tr.GetAccessToken() - require.Error(t, err) - - _, err = tr.GetRefreshToken() + _, err := tr.ParseIDToken() require.Error(t, err) }) }