Skip to content

Commit

Permalink
Cache access tokens client side (#3565)
Browse files Browse the repository at this point in the history
  • Loading branch information
woutslakhorst authored Nov 22, 2024
1 parent 18f72f1 commit 0fa8fb5
Show file tree
Hide file tree
Showing 12 changed files with 252 additions and 45 deletions.
46 changes: 45 additions & 1 deletion auth/api/iam/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ import (
"bytes"
"context"
"crypto"
"crypto/sha256"
"embed"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -72,6 +74,10 @@ type httpRequestContextKey struct{}
// TODO: Might want to make this configurable at some point
const accessTokenValidity = 15 * time.Minute

// accessTokenCacheOffset is used to reduce the ttl of the access token to ensure it is still valid when the client receives it.
// this to offset clock skew and roundtrip times
const accessTokenCacheOffset = 30 * time.Second

// cacheControlMaxAgeURLs holds API endpoints that should have a max-age cache control header set.
var cacheControlMaxAgeURLs = []string{
"/oauth2/:subjectID/presentation_definition",
Expand Down Expand Up @@ -722,6 +728,19 @@ func (r Wrapper) RequestServiceAccessToken(ctx context.Context, request RequestS
return nil, err
}

tokenCache := r.accessTokenCache()
cacheKey := accessTokenRequestCacheKey(request)

// try to retrieve token from cache
tokenResult := new(TokenResponse)
err = tokenCache.Get(cacheKey, tokenResult)
if err == nil {
return RequestServiceAccessToken200JSONResponse(*tokenResult), nil
} else if !errors.Is(err, storage.ErrNotFound) {
// only log error, don't fail
log.Logger().WithError(err).Warnf("Failed to retrieve access token from cache: %s", err.Error())
}

var credentials []VerifiableCredential
if request.Body.Credentials != nil {
credentials = *request.Body.Credentials
Expand All @@ -732,11 +751,22 @@ func (r Wrapper) RequestServiceAccessToken(ctx context.Context, request RequestS
useDPoP = false
}
clientID := r.subjectToBaseURL(request.SubjectID)
tokenResult, err := r.auth.IAMClient().RequestRFC021AccessToken(ctx, clientID.String(), request.SubjectID, request.Body.AuthorizationServer, request.Body.Scope, useDPoP, credentials)
tokenResult, err = r.auth.IAMClient().RequestRFC021AccessToken(ctx, clientID.String(), request.SubjectID, request.Body.AuthorizationServer, request.Body.Scope, useDPoP, credentials)
if err != nil {
// this can be an internal server error, a 400 oauth error or a 412 precondition failed if the wallet does not contain the required credentials
return nil, err
}
ttl := accessTokenValidity
if tokenResult.ExpiresIn != nil {
ttl = time.Second * time.Duration(*tokenResult.ExpiresIn)
}
// we reduce the ttl by accessTokenCacheOffset to make sure the token is expired when the cache expires
ttl -= accessTokenCacheOffset
err = tokenCache.Put(cacheKey, tokenResult, storage.WithTTL(ttl))
if err != nil {
// only log error, don't fail
log.Logger().WithError(err).Warnf("Failed to cache access token: %s", err.Error())
}
return RequestServiceAccessToken200JSONResponse(*tokenResult), nil
}

Expand Down Expand Up @@ -897,6 +927,12 @@ func (r Wrapper) accessTokenServerStore() storage.SessionStore {
return r.storageEngine.GetSessionDatabase().GetStore(accessTokenValidity, "serveraccesstoken")
}

// accessTokenClientStore is used by the client to cache access tokens
func (r Wrapper) accessTokenCache() storage.SessionStore {
// we use a slightly reduced validity to prevent the cache from being used after the token has expired
return r.storageEngine.GetSessionDatabase().GetStore(accessTokenValidity-accessTokenCacheOffset, "accesstokencache")
}

// accessTokenServerStore is used by the Auth server to store issued access tokens
func (r Wrapper) authzRequestObjectStore() storage.SessionStore {
return r.storageEngine.GetSessionDatabase().GetStore(accessTokenValidity, oauthRequestObjectKey...)
Expand Down Expand Up @@ -946,3 +982,11 @@ func (r Wrapper) determineClientDID(ctx context.Context, authServerMetadata oaut
}
return &candidateDIDs[0], nil
}

// accessTokenRequestCacheKey creates a cache key for the access token request.
// it writes the JSON to a sha256 hash and returns the hex encoded hash.
func accessTokenRequestCacheKey(request RequestServiceAccessTokenRequestObject) string {
hash := sha256.New()
_ = json.NewEncoder(hash).Encode(request)
return hex.EncodeToString(hash.Sum(nil))
}
71 changes: 64 additions & 7 deletions auth/api/iam/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,6 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/nuts-foundation/nuts-node/core/to"
"github.com/nuts-foundation/nuts-node/crypto/storage/spi"
test2 "github.com/nuts-foundation/nuts-node/crypto/test"
"github.com/nuts-foundation/nuts-node/http/user"
"github.com/nuts-foundation/nuts-node/test"
"github.com/nuts-foundation/nuts-node/vdr/didsubject"
"io"
"net/http"
"net/http/httptest"
Expand All @@ -51,10 +45,15 @@ import (
"github.com/nuts-foundation/nuts-node/auth/oauth"
oauthServices "github.com/nuts-foundation/nuts-node/auth/services/oauth"
"github.com/nuts-foundation/nuts-node/core"
"github.com/nuts-foundation/nuts-node/core/to"
cryptoNuts "github.com/nuts-foundation/nuts-node/crypto"
"github.com/nuts-foundation/nuts-node/crypto/storage/spi"
test2 "github.com/nuts-foundation/nuts-node/crypto/test"
"github.com/nuts-foundation/nuts-node/http/user"
"github.com/nuts-foundation/nuts-node/jsonld"
"github.com/nuts-foundation/nuts-node/policy"
"github.com/nuts-foundation/nuts-node/storage"
"github.com/nuts-foundation/nuts-node/test"
"github.com/nuts-foundation/nuts-node/vcr"
"github.com/nuts-foundation/nuts-node/vcr/credential"
"github.com/nuts-foundation/nuts-node/vcr/holder"
Expand All @@ -63,6 +62,7 @@ import (
"github.com/nuts-foundation/nuts-node/vcr/types"
"github.com/nuts-foundation/nuts-node/vcr/verifier"
"github.com/nuts-foundation/nuts-node/vdr"
"github.com/nuts-foundation/nuts-node/vdr/didsubject"
"github.com/nuts-foundation/nuts-node/vdr/resolver"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -865,11 +865,31 @@ func TestWrapper_RequestServiceAccessToken(t *testing.T) {

t.Run("ok", func(t *testing.T) {
ctx := newTestClient(t)
request := RequestServiceAccessTokenRequestObject{SubjectID: holderSubjectID, Body: body}
ctx.iamClient.EXPECT().RequestRFC021AccessToken(nil, holderClientID, holderSubjectID, verifierURL.String(), "first second", true, nil).Return(&oauth.TokenResponse{}, nil)

_, err := ctx.client.RequestServiceAccessToken(nil, RequestServiceAccessTokenRequestObject{SubjectID: holderSubjectID, Body: body})
token, err := ctx.client.RequestServiceAccessToken(nil, request)

require.NoError(t, err)

t.Run("is cached", func(t *testing.T) {
cachedToken, err := ctx.client.RequestServiceAccessToken(nil, request)

require.NoError(t, err)
assert.Equal(t, token, cachedToken)
})

t.Run("cache expired", func(t *testing.T) {
cacheKey := accessTokenRequestCacheKey(request)
_ = ctx.client.accessTokenCache().Delete(cacheKey)
ctx.iamClient.EXPECT().RequestRFC021AccessToken(nil, holderClientID, holderSubjectID, verifierURL.String(), "first second", true, nil).Return(&oauth.TokenResponse{AccessToken: "other"}, nil)

otherToken, err := ctx.client.RequestServiceAccessToken(nil, request)

require.NoError(t, err)

assert.NotEqual(t, token, otherToken)
})
})
t.Run("ok - no DPoP", func(t *testing.T) {
ctx := newTestClient(t)
Expand All @@ -885,6 +905,16 @@ func TestWrapper_RequestServiceAccessToken(t *testing.T) {

require.NoError(t, err)
})
t.Run("ok with expired cache by ttl", func(t *testing.T) {
ctx := newTestClient(t)
request := RequestServiceAccessTokenRequestObject{SubjectID: holderSubjectID, Body: body}
ctx.iamClient.EXPECT().RequestRFC021AccessToken(nil, holderClientID, holderSubjectID, verifierURL.String(), "first second", true, nil).Return(&oauth.TokenResponse{ExpiresIn: to.Ptr(5)}, nil)

_, err := ctx.client.RequestServiceAccessToken(nil, request)

require.NoError(t, err)
assert.False(t, ctx.client.accessTokenCache().Exists(accessTokenRequestCacheKey(request)))
})
t.Run("error - no matching credentials", func(t *testing.T) {
ctx := newTestClient(t)
ctx.iamClient.EXPECT().RequestRFC021AccessToken(nil, holderClientID, holderSubjectID, verifierURL.String(), "first second", true, nil).Return(nil, pe.ErrNoCredentials)
Expand All @@ -895,6 +925,24 @@ func TestWrapper_RequestServiceAccessToken(t *testing.T) {
assert.Equal(t, err, pe.ErrNoCredentials)
assert.Equal(t, http.StatusPreconditionFailed, statusCodeFrom(err))
})
t.Run("broken cache", func(t *testing.T) {
ctx := newTestClient(t)
mockStorage := storage.NewMockEngine(ctx.ctrl)
errorSessionDatabase := storage.NewErrorSessionDatabase(assert.AnError)
mockStorage.EXPECT().GetSessionDatabase().Return(errorSessionDatabase).AnyTimes()
ctx.client.storageEngine = mockStorage

request := RequestServiceAccessTokenRequestObject{SubjectID: holderSubjectID, Body: body}
ctx.iamClient.EXPECT().RequestRFC021AccessToken(nil, holderClientID, holderSubjectID, verifierURL.String(), "first second", true, nil).Return(&oauth.TokenResponse{AccessToken: "first"}, nil)
ctx.iamClient.EXPECT().RequestRFC021AccessToken(nil, holderClientID, holderSubjectID, verifierURL.String(), "first second", true, nil).Return(&oauth.TokenResponse{AccessToken: "second"}, nil)

token1, err := ctx.client.RequestServiceAccessToken(nil, request)
require.NoError(t, err)
token2, err := ctx.client.RequestServiceAccessToken(nil, request)
require.NoError(t, err)

assert.NotEqual(t, token1, token2)
})
}

func TestWrapper_RequestUserAccessToken(t *testing.T) {
Expand Down Expand Up @@ -1320,6 +1368,15 @@ func TestWrapper_subjectOwns(t *testing.T) {
})
}

func TestWrapper_accessTokenRequestCacheKey(t *testing.T) {
expected := "0cc6fbbd972c72de7bc86c6147347bdd54bcb41fe23cea3d8f61d6ddd75dbf86"
key := accessTokenRequestCacheKey(RequestServiceAccessTokenRequestObject{SubjectID: holderSubjectID, Body: &RequestServiceAccessTokenJSONRequestBody{Scope: "test"}})
other := accessTokenRequestCacheKey(RequestServiceAccessTokenRequestObject{SubjectID: holderSubjectID, Body: &RequestServiceAccessTokenJSONRequestBody{Scope: "test2"}})

assert.Equal(t, expected, key)
assert.NotEqual(t, key, other)
}

func createIssuerCredential(issuerDID did.DID, holderDID did.DID) *vc.VerifiableCredential {
privateKey, _ := spi.GenerateKeyPair()
credType := ssi.MustParseURI("ExampleType")
Expand Down
2 changes: 1 addition & 1 deletion storage/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ func (e *engine) Shutdown() error {
}

// Close session database
e.sessionDatabase.close()
e.sessionDatabase.Close()
// Close SQL db
if e.sqlDB != nil {
underlyingDB, err := e.sqlDB.DB()
Expand Down
23 changes: 20 additions & 3 deletions storage/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,11 @@ type SessionDatabase interface {
// The keys are used to logically partition the store, eg: tenants and/or flows that are not allowed to overlap like credential issuance and verification.
// The TTL is the time-to-live for the entries in the store.
GetStore(ttl time.Duration, keys ...string) SessionStore
// close stops any background processes and closes the database.
close()
// getFullKey returns the full key for the given key and prefixes.
// the supported chars differ per backend.
getFullKey(prefixes []string, key string) string
// Close stops any background processes and closes the database.
Close()
}

// SessionStore is a key-value store that holds session data.
Expand All @@ -95,10 +97,25 @@ type SessionStore interface {
// Returns ErrNotFound if the key does not exist.
Get(key string, target interface{}) error
// Put stores the given value for the given key.
Put(key string, value interface{}) error
// options can be used to fine-tune the storage of the item.
Put(key string, value interface{}, options ...SessionOption) error
// GetAndDelete combines Get and Delete as a convenience for burning nonce entries.
GetAndDelete(key string, target interface{}) error
}

// TransactionKey is the key used to store the SQL transaction in the context.
type TransactionKey struct{}

// SessionOption is an option that can be given when storing items.
type SessionOption func(target *sessionOptions)

type sessionOptions struct {
ttl time.Duration
}

// WithTTL sets the time-to-live for the stored item.
func WithTTL(ttl time.Duration) SessionOption {
return func(target *sessionOptions) {
target.ttl = ttl
}
}
37 changes: 21 additions & 16 deletions storage/mock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 17 additions & 2 deletions storage/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,16 +85,31 @@ func (s SessionStoreImpl[T]) Get(key string, target interface{}) error {
return json.Unmarshal([]byte(val), target)
}

func (s SessionStoreImpl[T]) Put(key string, value interface{}) error {
func (s SessionStoreImpl[T]) Put(key string, value interface{}, options ...SessionOption) error {
opts := s.defaultOptions()
for _, opt := range options {
opt(&opts)
}
// TTL can't go below 0 because that is translated to "no expiration" by the library
// so just don't cache
if opts.ttl <= 0 {
return nil
}
bytes, err := json.Marshal(value)
if err != nil {
return err
}
return s.underlying.Set(context.Background(), s.db.getFullKey(s.prefixes, key), T(bytes), store.WithExpiration(s.ttl))
return s.underlying.Set(context.Background(), s.db.getFullKey(s.prefixes, key), T(bytes), store.WithExpiration(opts.ttl))
}
func (s SessionStoreImpl[T]) GetAndDelete(key string, target interface{}) error {
if err := s.Get(key, target); err != nil {
return err
}
return s.underlying.Delete(context.Background(), s.db.getFullKey(s.prefixes, key))
}

func (s SessionStoreImpl[T]) defaultOptions() sessionOptions {
return sessionOptions{
ttl: s.ttl,
}
}
2 changes: 1 addition & 1 deletion storage/session_inmemory.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func (s *InMemorySessionDatabase) GetStore(ttl time.Duration, keys ...string) Se
}
}

func (s *InMemorySessionDatabase) close() {
func (s *InMemorySessionDatabase) Close() {
// NOP
}

Expand Down
Loading

0 comments on commit 0fa8fb5

Please sign in to comment.