From e0e657a9d08a1e90ac33197e5d216e799f8d8584 Mon Sep 17 00:00:00 2001 From: Ignasi Barrera Date: Thu, 29 Feb 2024 12:14:08 +0530 Subject: [PATCH] Fix races in jwks service (#65) --- internal/oidc/jwks.go | 28 +++++++++++----------------- internal/oidc/jwks_test.go | 37 +++++++++++++++++++------------------ 2 files changed, 30 insertions(+), 35 deletions(-) diff --git a/internal/oidc/jwks.go b/internal/oidc/jwks.go index 353738d..546c3b6 100644 --- a/internal/oidc/jwks.go +++ b/internal/oidc/jwks.go @@ -35,7 +35,7 @@ var ( // ErrJWKSFetch is returned when the JWKS document cannot be fetched. ErrJWKSFetch = errors.New("error fetching JWKS document") - _ run.Service = (*DefaultJWKSProvider)(nil) + _ run.ServiceContext = (*DefaultJWKSProvider)(nil) ) // DefaultFetchInterval is the default interval to use when none is set. @@ -49,10 +49,10 @@ type JWKSProvider interface { // DefaultJWKSProvider provides a JWKS set type DefaultJWKSProvider struct { - log telemetry.Logger - cache *jwk.AutoRefresh - shutdown context.CancelFunc - tlsPool internal.TLSConfigPool + log telemetry.Logger + cache *jwk.AutoRefresh + tlsPool internal.TLSConfigPool + started chan struct{} } // NewJWKSProvider returns a new JWKSProvider. @@ -60,20 +60,20 @@ func NewJWKSProvider(tlsPool internal.TLSConfigPool) *DefaultJWKSProvider { return &DefaultJWKSProvider{ log: internal.Logger(internal.JWKS), tlsPool: tlsPool, + started: make(chan struct{}), } } // Name of the JWKSProvider run.Unit func (j *DefaultJWKSProvider) Name() string { return "JWKS" } -// Serve implements run.Service -func (j *DefaultJWKSProvider) Serve() error { - ctx, cancel := context.WithCancel(context.Background()) - j.shutdown = cancel - +func (j *DefaultJWKSProvider) ServeContext(ctx context.Context) error { ch := make(chan jwk.AutoRefreshError) j.cache = jwk.NewAutoRefresh(ctx) j.cache.ErrorSink(ch) + defer func() { close(ch) }() + + close(j.started) // signal channel start for { select { @@ -85,16 +85,10 @@ func (j *DefaultJWKSProvider) Serve() error { } } -// GracefulStop implements run.Service -func (j *DefaultJWKSProvider) GracefulStop() { - if j.shutdown != nil { - j.shutdown() - } -} - // Get the JWKS for the given OIDC configuration func (j *DefaultJWKSProvider) Get(ctx context.Context, config *oidcv1.OIDCConfig) (jwk.Set, error) { if config.GetJwksFetcher() != nil { + <-j.started // wait until the service is fully started return j.fetchDynamic(ctx, config) } return j.fetchStatic(config.GetJwks()) diff --git a/internal/oidc/jwks_test.go b/internal/oidc/jwks_test.go index f25c5b0..5f772b2 100644 --- a/internal/oidc/jwks_test.go +++ b/internal/oidc/jwks_test.go @@ -22,6 +22,7 @@ import ( "net/http" "net/http/httptest" "strings" + "sync/atomic" "testing" "time" @@ -80,9 +81,10 @@ func TestStaticJWKSProvider(t *testing.T) { tlsPool := internal.NewTLSConfigPool(context.Background()) t.Run("invalid", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) cache := NewJWKSProvider(tlsPool) - go func() { require.NoError(t, cache.Serve()) }() - t.Cleanup(cache.GracefulStop) + go func() { require.NoError(t, cache.ServeContext(ctx)) }() + t.Cleanup(cancel) _, err := cache.Get(context.Background(), &oidcv1.OIDCConfig{ JwksConfig: &oidcv1.OIDCConfig_Jwks{ @@ -94,9 +96,10 @@ func TestStaticJWKSProvider(t *testing.T) { }) t.Run("single-key", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) cache := NewJWKSProvider(tlsPool) - go func() { require.NoError(t, cache.Serve()) }() - t.Cleanup(cache.GracefulStop) + go func() { require.NoError(t, cache.ServeContext(ctx)) }() + t.Cleanup(cancel) jwks, err := cache.Get(context.Background(), &oidcv1.OIDCConfig{ JwksConfig: &oidcv1.OIDCConfig_Jwks{ @@ -115,9 +118,10 @@ func TestStaticJWKSProvider(t *testing.T) { }) t.Run("multiple-keys", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) cache := NewJWKSProvider(tlsPool) - go func() { require.NoError(t, cache.Serve()) }() - t.Cleanup(cache.GracefulStop) + go func() { require.NoError(t, cache.ServeContext(ctx)) }() + t.Cleanup(cancel) jwks, err := cache.Get(context.Background(), &oidcv1.OIDCConfig{ JwksConfig: &oidcv1.OIDCConfig_Jwks{ @@ -153,12 +157,9 @@ func TestDynamicJWKSProvider(t *testing.T) { g := run.Group{Logger: telemetry.NoopLogger()} g.Register(cache) go func() { _ = g.Run() }() - t.Cleanup(cache.GracefulStop) // Block until the cache is initialized - require.Eventually(t, func() bool { - return cache.cache != nil - }, 10*time.Second, 50*time.Millisecond) + <-cache.started return cache } ) @@ -178,7 +179,7 @@ func TestDynamicJWKSProvider(t *testing.T) { _, err := cache.Get(context.Background(), config) require.ErrorIs(t, err, ErrJWKSFetch) - require.Equal(t, 1, server.requestCount) // The attempt to load the JWKS is made, but fails + require.Equal(t, int32(1), atomic.LoadInt32(server.requestCount)) // The attempt to load the JWKS is made, but fails }) t.Run("cache load", func(t *testing.T) { @@ -198,7 +199,7 @@ func TestDynamicJWKSProvider(t *testing.T) { keys, err := cache.Get(context.Background(), config) require.NoError(t, err) require.Equal(t, jwks, keys) - require.Equal(t, 1, server.requestCount) + require.Equal(t, int32(1), atomic.LoadInt32(server.requestCount)) }) t.Run("cached results", func(t *testing.T) { @@ -218,7 +219,7 @@ func TestDynamicJWKSProvider(t *testing.T) { keys, err := cache.Get(context.Background(), config) require.NoError(t, err) require.Equal(t, jwks, keys) - require.Equal(t, 1, server.requestCount) // Cached results after the first request + require.Equal(t, int32(1), atomic.LoadInt32(server.requestCount)) // Cached results after the first request } }) @@ -242,20 +243,20 @@ func TestDynamicJWKSProvider(t *testing.T) { // Wait for the refresh period and check that the JWKS has been refreshed require.Eventually(t, func() bool { - return server.requestCount > 1 + return atomic.LoadInt32(server.requestCount) > 1 }, 3*time.Second, 1*time.Second) }) } type server struct { *httptest.Server - requestCount int + requestCount *int32 } func newTestServer(t *testing.T, jwks jwk.Set) *server { - s := &server{} + s := &server{requestCount: new(int32)} s.Server = httptest.NewServer(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { - s.requestCount++ + atomic.AddInt32(s.requestCount, 1) if strings.HasSuffix(req.URL.Path, "/not-found") { res.WriteHeader(404) @@ -267,7 +268,7 @@ func newTestServer(t *testing.T, jwks jwk.Set) *server { res.WriteHeader(200) _, _ = res.Write(bytes) })) - t.Cleanup(func() { s.requestCount = 0 }) + t.Cleanup(func() { atomic.StoreInt32(s.requestCount, 0) }) t.Cleanup(s.Close) return s }