Skip to content

Commit

Permalink
feat: Cache requests for JWKS on JWT verification (#228)
Browse files Browse the repository at this point in the history
The jwt.Verify method needs to fetch the JSON Web Key Set from the
API in order to verify the session JWT's validity.
The jwt.Verify method is used in the http.WithHeaderAuthorization
middleware, which means that in an HTTP server context, the method will
executed for every request.
We're adding a caching layer for the JWKS when we verify the session
JWT. This way we can cache the JWKS response from the API for 1 hour.
  • Loading branch information
gkats authored Feb 12, 2024
1 parent 02b7940 commit 326d38b
Show file tree
Hide file tree
Showing 2 changed files with 179 additions and 3 deletions.
107 changes: 104 additions & 3 deletions jwt/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"context"
"fmt"
"strings"
"sync"
"time"

"github.com/clerk/clerk-sdk-go/v2"
Expand Down Expand Up @@ -103,10 +104,9 @@ func Verify(ctx context.Context, params *VerifyParams) (*clerk.SessionClaims, er
return claims, nil
}

// Retrieve the JSON web key for the provided id from the set.
func getJWK(ctx context.Context, kid string) (*clerk.JSONWebKey, error) {
// TODO Avoid multiple requests by caching results for the same
// instance.
jwks, err := jwks.Get(ctx, &jwks.GetParams{})
jwks, err := getJWKSWithCache(ctx)
if err != nil {
return nil, err
}
Expand All @@ -118,6 +118,44 @@ func getJWK(ctx context.Context, kid string) (*clerk.JSONWebKey, error) {
return nil, fmt.Errorf("no jwk key found for kid %s", kid)
}

// Returns the JSON web key set. Tries a cached value first, but if
// there's no value or the entry has expired, it will fetch the set
// from the API and cache the value.
func getJWKSWithCache(ctx context.Context) (*clerk.JSONWebKeySet, error) {
const cacheKey = "/v1/jwks"
var jwks *clerk.JSONWebKeySet
var err error

// Try the cache first. Make sure we have a non-expired entry and
// that the value is a valid JWKS.
entry, ok := getCache().Get(cacheKey)
if ok && !entry.HasExpired() {
jwks, ok = entry.GetValue().(*clerk.JSONWebKeySet)
if !ok || jwks == nil || len(jwks.Keys) == 0 {
jwks, err = forceGetJWKS(ctx, cacheKey)
if err != nil {
return nil, err
}
}
} else {
jwks, err = forceGetJWKS(ctx, cacheKey)
if err != nil {
return nil, err
}
}
return jwks, err
}

// Fetches the JSON web key set from the API and caches it.
func forceGetJWKS(ctx context.Context, cacheKey string) (*clerk.JSONWebKeySet, error) {
jwks, err := jwks.Get(ctx, &jwks.GetParams{})
if err != nil {
return nil, err
}
getCache().Set(cacheKey, jwks, time.Now().UTC().Add(time.Hour))
return jwks, nil
}

func isValidIssuer(iss string) bool {
return strings.HasPrefix(iss, "https://clerk.") ||
strings.Contains(iss, ".clerk.accounts")
Expand Down Expand Up @@ -154,3 +192,66 @@ func Decode(_ context.Context, params *DecodeParams) (*clerk.Claims, error) {
Extra: extraClaims,
}, nil
}

// Caching store.
type cache struct {
mu sync.RWMutex
entries map[string]*cacheEntry
}

// Get returns the cache entry for the provided key, if one exists.
func (c *cache) Get(key string) (*cacheEntry, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
entry, ok := c.entries[key]
return entry, ok
}

// Set adds a new entry with the provided value in the cache under
// the provided key. An expiration date will be set for the entry.
func (c *cache) Set(key string, value any, expiresAt time.Time) {
c.mu.Lock()
defer c.mu.Unlock()
c.entries[key] = &cacheEntry{
value: value,
expiresAt: expiresAt,
}
}

// A cache entry has a value and an expiration date.
type cacheEntry struct {
value any
expiresAt time.Time
}

// HasExpired returns true if the cache entry's expiration date
// has passed.
func (entry *cacheEntry) HasExpired() bool {
if entry == nil {
return true
}
return entry.expiresAt.Before(time.Now())
}

// GetValue returns the cache entry's value.
func (entry *cacheEntry) GetValue() any {
if entry == nil {
return nil
}
return entry.value
}

var cacheInit sync.Once

// A "singleton" cache for the package.
var defaultCache *cache

// Lazy initialize and return the default cache singleton.
func getCache() *cache {
cacheInit.Do(func() {
defaultCache = &cache{
entries: map[string]*cacheEntry{},
}
})
return defaultCache
}
75 changes: 75 additions & 0 deletions jwt/jwt_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package jwt

import (
"context"
"net/http"
"net/http/httptest"
"testing"

"github.com/clerk/clerk-sdk-go/v2"
"github.com/clerk/clerk-sdk-go/v2/clerktest"
"github.com/stretchr/testify/require"
)

func TestVerify_InvalidToken(t *testing.T) {
clerk.SetBackend(clerk.NewBackend(&clerk.BackendConfig{
HTTPClient: &http.Client{
Transport: &clerktest.RoundTripper{},
},
}))

ctx := context.Background()
_, err := Verify(ctx, &VerifyParams{
Token: "this-is-not-a-token",
})
require.Error(t, err)
}

func TestVerify_Cache(t *testing.T) {
ctx := context.Background()
totalRequests := 0
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodGet && r.URL.Path == "/v1/jwks" {
totalRequests++
}
_, err := w.Write([]byte(`{
"keys": [{
"use": "sig",
"kty": "RSA",
"kid": "ins_123",
"alg": "RS256",
"n": "9m1LJW0dgEuK8SnN1Oy4LY8vaWABVS-hBTMA--_4LN1PZlMS5B2RPL85WkXYlHb0KXOSVrFKZLwYP-a9l3MFlW2YrPVAIvYfqPyqY5fmSEf-2qfrwosIhB2NSHyNRBQQ8-BX1RO9rIXIqYDKxGqktqMvYJmEGClmijbmFyQb2hpHD5PDbAB_DZvpZTEzWcQBL2ytHehILkYfg-ZZRyt7O8h5Gdy1v_TUlg8iMvchHlAkrIAmXNQigZmX_lne91tW8t4KMNJRfmUyLVCLbPnwxlmXXcice-0tmFw0OkCOteNWBeRNctJ3AIreGMzaJOJ2HeSUmJoX8iRKLLT3fsURLw",
"e": "AQAB"
}]
}`))
require.NoError(t, err)
}))
defer ts.Close()

clerk.SetBackend(clerk.NewBackend(&clerk.BackendConfig{
HTTPClient: ts.Client(),
URL: clerk.String(ts.URL),
}))

token := "eyJhbGciOiJSUzI1NiIsImNhdCI6ImNsX0I3ZDRQRDExMUFBQSIsImtpZCI6Imluc18yOWR6bUdmQ3JydzdSMDRaVFFZRDNKSTB5dkYiLCJ0eXAiOiJKV1QifQ.eyJhenAiOiJodHRwczovL2Rhc2hib2FyZC5wcm9kLmxjbGNsZXJrLmNvbSIsImV4cCI6MTcwNzMwMDMyMiwiaWF0IjoxNzA3MzAwMjYyLCJpc3MiOiJodHRwczovL2NsZXJrLnByb2QubGNsY2xlcmsuY29tIiwibmJmIjoxNzA3MzAwMjUyLCJvcmdzIjp7Im9yZ18ySUlwcVIxenFNeHJQQkhSazNzTDJOSnJUQkQiOiJvcmc6YWRtaW4iLCJvcmdfMllHMlNwd0IzWEJoNUo0ZXF5elFVb0dXMjVhIjoib3JnOmFkbWluIiwib3JnXzJhZzJ6bmgxWGFjTXI0dGRXYjZRbEZSQ2RuaiI6Im9yZzphZG1pbiIsIm9yZ18yYWlldHlXa3VFSEhaRmRSUTFvVjYzMnZWaFciOiJvcmc6YWRtaW4ifSwic2lkIjoic2Vzc18yYm84b2gyRnIyeTNueVoyRVZQYktBd2ZvaU0iLCJzdWIiOiJ1c2VyXzI5ZTBXTnp6M245V1Q5S001WlpJYTBVVjNDNyJ9.6GtQafMBYY3Ij3pKHOyBYKt76LoLeBC71QUY_ho3k5nb0FBSvV0upKFLPBvIXNuF7hH0FK2QqDcAmrhbzAI-2qF_Ynve8Xl4VZCRpbTuZI7uL-tVjCvMffEIH-BHtrZ-QcXhEmNFQNIPyZTu21242he7U6o4S8st_aLmukWQzj_4qir7o5_fmVhm7YkLa0gYG5SLjkr2czwem1VGFHEVEOrHjun-g6eMnDNMMMysIOkZFxeqiCnqpc4u1V7Z7jfoK0r_-Unp8mGGln5KWYMCQyp1l1SkGwugtxeWfSbE4eklKRmItGOdVftvTyG16kDGpzsb22AQGtg65Iygni4PHg"
// Providing a custom key will not trigger a request to fetch the
// key set.
_, _ = Verify(ctx, &VerifyParams{
Token: token,
JWK: &clerk.JSONWebKey{},
})
require.Equal(t, 0, totalRequests)

// Verify without providing a key. The method will trigger a request
// to fetch the key set.
_, _ = Verify(ctx, &VerifyParams{
Token: token,
})
require.Equal(t, 1, totalRequests)
// Verifying again won't trigger a request because the key set is
// cached.
_, _ = Verify(ctx, &VerifyParams{
Token: token,
})
require.Equal(t, 1, totalRequests)
}

0 comments on commit 326d38b

Please sign in to comment.