Skip to content
This repository has been archived by the owner on Apr 22, 2024. It is now read-only.

Commit

Permalink
Fix races in jwks service (#65)
Browse files Browse the repository at this point in the history
  • Loading branch information
nacx authored Feb 29, 2024
1 parent 50ba199 commit e0e657a
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 35 deletions.
28 changes: 11 additions & 17 deletions internal/oidc/jwks.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ var (
// ErrJWKSFetch is returned when the JWKS document cannot be fetched.
ErrJWKSFetch = errors.New("error fetching JWKS document")

_ run.Service = (*DefaultJWKSProvider)(nil)
_ run.ServiceContext = (*DefaultJWKSProvider)(nil)
)

// DefaultFetchInterval is the default interval to use when none is set.
Expand All @@ -49,31 +49,31 @@ type JWKSProvider interface {

// DefaultJWKSProvider provides a JWKS set
type DefaultJWKSProvider struct {
log telemetry.Logger
cache *jwk.AutoRefresh
shutdown context.CancelFunc
tlsPool internal.TLSConfigPool
log telemetry.Logger
cache *jwk.AutoRefresh
tlsPool internal.TLSConfigPool
started chan struct{}
}

// NewJWKSProvider returns a new JWKSProvider.
func NewJWKSProvider(tlsPool internal.TLSConfigPool) *DefaultJWKSProvider {
return &DefaultJWKSProvider{
log: internal.Logger(internal.JWKS),
tlsPool: tlsPool,
started: make(chan struct{}),
}
}

// Name of the JWKSProvider run.Unit
func (j *DefaultJWKSProvider) Name() string { return "JWKS" }

// Serve implements run.Service
func (j *DefaultJWKSProvider) Serve() error {
ctx, cancel := context.WithCancel(context.Background())
j.shutdown = cancel

func (j *DefaultJWKSProvider) ServeContext(ctx context.Context) error {
ch := make(chan jwk.AutoRefreshError)
j.cache = jwk.NewAutoRefresh(ctx)
j.cache.ErrorSink(ch)
defer func() { close(ch) }()

close(j.started) // signal channel start

for {
select {
Expand All @@ -85,16 +85,10 @@ func (j *DefaultJWKSProvider) Serve() error {
}
}

// GracefulStop implements run.Service
func (j *DefaultJWKSProvider) GracefulStop() {
if j.shutdown != nil {
j.shutdown()
}
}

// Get the JWKS for the given OIDC configuration
func (j *DefaultJWKSProvider) Get(ctx context.Context, config *oidcv1.OIDCConfig) (jwk.Set, error) {
if config.GetJwksFetcher() != nil {
<-j.started // wait until the service is fully started
return j.fetchDynamic(ctx, config)
}
return j.fetchStatic(config.GetJwks())
Expand Down
37 changes: 19 additions & 18 deletions internal/oidc/jwks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"net/http"
"net/http/httptest"
"strings"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -80,9 +81,10 @@ func TestStaticJWKSProvider(t *testing.T) {
tlsPool := internal.NewTLSConfigPool(context.Background())

t.Run("invalid", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cache := NewJWKSProvider(tlsPool)
go func() { require.NoError(t, cache.Serve()) }()
t.Cleanup(cache.GracefulStop)
go func() { require.NoError(t, cache.ServeContext(ctx)) }()
t.Cleanup(cancel)

_, err := cache.Get(context.Background(), &oidcv1.OIDCConfig{
JwksConfig: &oidcv1.OIDCConfig_Jwks{
Expand All @@ -94,9 +96,10 @@ func TestStaticJWKSProvider(t *testing.T) {
})

t.Run("single-key", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cache := NewJWKSProvider(tlsPool)
go func() { require.NoError(t, cache.Serve()) }()
t.Cleanup(cache.GracefulStop)
go func() { require.NoError(t, cache.ServeContext(ctx)) }()
t.Cleanup(cancel)

jwks, err := cache.Get(context.Background(), &oidcv1.OIDCConfig{
JwksConfig: &oidcv1.OIDCConfig_Jwks{
Expand All @@ -115,9 +118,10 @@ func TestStaticJWKSProvider(t *testing.T) {
})

t.Run("multiple-keys", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cache := NewJWKSProvider(tlsPool)
go func() { require.NoError(t, cache.Serve()) }()
t.Cleanup(cache.GracefulStop)
go func() { require.NoError(t, cache.ServeContext(ctx)) }()
t.Cleanup(cancel)

jwks, err := cache.Get(context.Background(), &oidcv1.OIDCConfig{
JwksConfig: &oidcv1.OIDCConfig_Jwks{
Expand Down Expand Up @@ -153,12 +157,9 @@ func TestDynamicJWKSProvider(t *testing.T) {
g := run.Group{Logger: telemetry.NoopLogger()}
g.Register(cache)
go func() { _ = g.Run() }()
t.Cleanup(cache.GracefulStop)

// Block until the cache is initialized
require.Eventually(t, func() bool {
return cache.cache != nil
}, 10*time.Second, 50*time.Millisecond)
<-cache.started
return cache
}
)
Expand All @@ -178,7 +179,7 @@ func TestDynamicJWKSProvider(t *testing.T) {
_, err := cache.Get(context.Background(), config)

require.ErrorIs(t, err, ErrJWKSFetch)
require.Equal(t, 1, server.requestCount) // The attempt to load the JWKS is made, but fails
require.Equal(t, int32(1), atomic.LoadInt32(server.requestCount)) // The attempt to load the JWKS is made, but fails
})

t.Run("cache load", func(t *testing.T) {
Expand All @@ -198,7 +199,7 @@ func TestDynamicJWKSProvider(t *testing.T) {
keys, err := cache.Get(context.Background(), config)
require.NoError(t, err)
require.Equal(t, jwks, keys)
require.Equal(t, 1, server.requestCount)
require.Equal(t, int32(1), atomic.LoadInt32(server.requestCount))
})

t.Run("cached results", func(t *testing.T) {
Expand All @@ -218,7 +219,7 @@ func TestDynamicJWKSProvider(t *testing.T) {
keys, err := cache.Get(context.Background(), config)
require.NoError(t, err)
require.Equal(t, jwks, keys)
require.Equal(t, 1, server.requestCount) // Cached results after the first request
require.Equal(t, int32(1), atomic.LoadInt32(server.requestCount)) // Cached results after the first request
}
})

Expand All @@ -242,20 +243,20 @@ func TestDynamicJWKSProvider(t *testing.T) {

// Wait for the refresh period and check that the JWKS has been refreshed
require.Eventually(t, func() bool {
return server.requestCount > 1
return atomic.LoadInt32(server.requestCount) > 1
}, 3*time.Second, 1*time.Second)
})
}

type server struct {
*httptest.Server
requestCount int
requestCount *int32
}

func newTestServer(t *testing.T, jwks jwk.Set) *server {
s := &server{}
s := &server{requestCount: new(int32)}
s.Server = httptest.NewServer(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
s.requestCount++
atomic.AddInt32(s.requestCount, 1)

if strings.HasSuffix(req.URL.Path, "/not-found") {
res.WriteHeader(404)
Expand All @@ -267,7 +268,7 @@ func newTestServer(t *testing.T, jwks jwk.Set) *server {
res.WriteHeader(200)
_, _ = res.Write(bytes)
}))
t.Cleanup(func() { s.requestCount = 0 })
t.Cleanup(func() { atomic.StoreInt32(s.requestCount, 0) })
t.Cleanup(s.Close)
return s
}
Expand Down

0 comments on commit e0e657a

Please sign in to comment.