From 9d0e3d382444937431905bdad4f2bff789b36b86 Mon Sep 17 00:00:00 2001 From: Ignasi Barrera Date: Wed, 14 Feb 2024 10:21:19 +0100 Subject: [PATCH] Load the session stores for all OIDC configurations (#10) --- .github/codecov.yml | 7 ++- cmd/main.go | 4 +- internal/authz/oidc.go | 27 ++++------- internal/oidc/jwks_test.go | 10 ++-- internal/oidc/redis.go | 25 ++++++++++ internal/oidc/session.go | 61 ++++++++++++++++++++++++ internal/oidc/session_test.go | 90 +++++++++++++++++++++++++++++++++++ internal/server/authz.go | 18 +++---- internal/server/authz_test.go | 10 ++-- 9 files changed, 216 insertions(+), 36 deletions(-) create mode 100644 internal/oidc/redis.go create mode 100644 internal/oidc/session_test.go diff --git a/.github/codecov.yml b/.github/codecov.yml index ab9da73..b1e997a 100644 --- a/.github/codecov.yml +++ b/.github/codecov.yml @@ -10,7 +10,7 @@ comment: layout: "diff,files" ignore: - - "api/**/*" + - "config/gen/**/*" coverage: status: @@ -20,3 +20,8 @@ coverage: target: auto # allow a potential drop of up to 5% threshold: 5% + patch: + default: + target: auto + only_pulls: true + threshold: 0% diff --git a/cmd/main.go b/cmd/main.go index 7f3f235..b9832d7 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -33,7 +33,8 @@ func main() { configFile = &internal.LocalConfigFile{} logging = internal.NewLogSystem(log.New(), &configFile.Config) jwks = oidc.NewJWKSProvider() - envoyAuthz = server.NewExtAuthZFilter(&configFile.Config, jwks) + sessions = &oidc.SessionStoreFactory{Config: &configFile.Config} + envoyAuthz = server.NewExtAuthZFilter(&configFile.Config, jwks, sessions) authzServer = server.New(&configFile.Config, envoyAuthz.Register) ) @@ -52,6 +53,7 @@ func main() { logging, // set up the logging system configLog, // log the configuration jwks, // start the JWKS provider + sessions, // start the session store authzServer, // start the server &signal.Handler{}, // handle graceful termination ) diff --git a/internal/authz/oidc.go b/internal/authz/oidc.go index bd9706f..8a9bffc 100644 --- a/internal/authz/oidc.go +++ b/internal/authz/oidc.go @@ -16,7 +16,6 @@ package authz import ( "context" - "time" envoy "github.com/envoyproxy/go-control-plane/envoy/service/auth/v3" "github.com/tetratelabs/telemetry" @@ -31,27 +30,19 @@ var _ Handler = (*oidcHandler)(nil) // oidc handler is an implementation of the Handler interface that implements // the OpenID connect protocol. type oidcHandler struct { - log telemetry.Logger - config *oidcv1.OIDCConfig - store oidc.SessionStore - jwks oidc.JWKSProvider + log telemetry.Logger + config *oidcv1.OIDCConfig + jwks oidc.JWKSProvider + sessions *oidc.SessionStoreFactory } // NewOIDCHandler creates a new OIDC implementation of the Handler interface. -func NewOIDCHandler(cfg *oidcv1.OIDCConfig, jwks oidc.JWKSProvider) (Handler, error) { - // TODO(nacx): Read the redis store config to configure the redi store - // TODO(nacx): Properly lifecycle the session store - store := oidc.NewMemoryStore( - oidc.Clock{}, - time.Duration(cfg.AbsoluteSessionTimeout), - time.Duration(cfg.IdleSessionTimeout), - ) - +func NewOIDCHandler(cfg *oidcv1.OIDCConfig, jwks oidc.JWKSProvider, sessions *oidc.SessionStoreFactory) (Handler, error) { return &oidcHandler{ - log: internal.Logger(internal.Authz).With("type", "oidc"), - config: cfg, - store: store, - jwks: jwks, + log: internal.Logger(internal.Authz).With("type", "oidc"), + config: cfg, + jwks: jwks, + sessions: sessions, }, nil } diff --git a/internal/oidc/jwks_test.go b/internal/oidc/jwks_test.go index 9cfe67f..aa1d6ec 100644 --- a/internal/oidc/jwks_test.go +++ b/internal/oidc/jwks_test.go @@ -28,6 +28,8 @@ import ( "github.com/lestrrat-go/jwx/jwa" "github.com/lestrrat-go/jwx/jwk" "github.com/stretchr/testify/require" + "github.com/tetratelabs/run" + "github.com/tetratelabs/telemetry" oidcv1 "github.com/tetrateio/authservice-go/config/gen/go/v1/oidc" ) @@ -143,8 +145,11 @@ func TestDynamicJWKSProvider(t *testing.T) { newCache = func(t *testing.T) JWKSProvider { cache := NewJWKSProvider() - go func() { require.NoError(t, cache.Serve()) }() + g := run.Group{Logger: telemetry.NoopLogger()} + g.Register(cache) + go func() { _ = g.Run() }() t.Cleanup(cache.GracefulStop) + // Block until the cache is initialized require.Eventually(t, func() bool { return cache.cache != nil @@ -160,8 +165,7 @@ func TestDynamicJWKSProvider(t *testing.T) { config := &oidcv1.OIDCConfig{ JwksConfig: &oidcv1.OIDCConfig_JwksFetcher{ JwksFetcher: &oidcv1.OIDCConfig_JwksFetcherConfig{ - JwksUri: server.URL + "/not-found", - PeriodicFetchIntervalSec: 1, + JwksUri: server.URL + "/not-found", }, }, } diff --git a/internal/oidc/redis.go b/internal/oidc/redis.go new file mode 100644 index 0000000..b65de68 --- /dev/null +++ b/internal/oidc/redis.go @@ -0,0 +1,25 @@ +// Copyright 2024 Tetrate +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package oidc + +var _ SessionStore = (*redisStore)(nil) + +// redisStore is an in-memory implementation of the SessionStore interface that stores +// the session data in a given Redis server. +type redisStore struct { + // TODO(nacx): Remove the interface embedding and implement it + SessionStore + url string +} diff --git a/internal/oidc/session.go b/internal/oidc/session.go index 9c0cba7..4013fec 100644 --- a/internal/oidc/session.go +++ b/internal/oidc/session.go @@ -14,6 +14,15 @@ package oidc +import ( + "time" + + "github.com/tetratelabs/run" + + configv1 "github.com/tetrateio/authservice-go/config/gen/go/v1" + oidcv1 "github.com/tetrateio/authservice-go/config/gen/go/v1/oidc" +) + // SessionStore is an interface for storing session data. type SessionStore interface { SetTokenResponse(sessionID string, tokenResponse *TokenResponse) @@ -24,3 +33,55 @@ type SessionStore interface { RemoveSession(sessionID string) RemoveAllExpired() } + +var _ run.PreRunner = (*SessionStoreFactory)(nil) + +// SessionStoreFactory is a factory for creating session stores. +// It uses the OIDC configuration to determine which store to use. +type SessionStoreFactory struct { + Config *configv1.Config + + redis map[string]SessionStore + memory SessionStore +} + +// Name implements run.Unit. +func (s *SessionStoreFactory) Name() string { return "OIDC session store factory" } + +// PreRun initializes the stores that are defined in the configuration +func (s *SessionStoreFactory) PreRun() error { + s.redis = make(map[string]SessionStore) + + for _, fc := range s.Config.Chains { + for _, f := range fc.Filters { + if f.GetOidc() == nil { + continue + } + + if redisServer := f.GetOidc().GetRedisSessionStoreConfig().GetServerUri(); redisServer != "" { + // TODO(nacx): Initialize the Redis store + s.redis[redisServer] = &redisStore{url: redisServer} + } else if s.memory == nil { // Use a shared in-memory store for all OIDC configurations + s.memory = NewMemoryStore( + Clock{}, + time.Duration(f.GetOidc().GetAbsoluteSessionTimeout()), + time.Duration(f.GetOidc().GetIdleSessionTimeout()), + ) + } + } + } + + return nil +} + +// Get returns the appropriate session store for the given OIDC configuration. +func (s *SessionStoreFactory) Get(cfg *oidcv1.OIDCConfig) SessionStore { + if cfg == nil { + return nil + } + store, ok := s.redis[cfg.GetRedisSessionStoreConfig().GetServerUri()] + if !ok { + store = s.memory + } + return store +} diff --git a/internal/oidc/session_test.go b/internal/oidc/session_test.go new file mode 100644 index 0000000..0fee1b3 --- /dev/null +++ b/internal/oidc/session_test.go @@ -0,0 +1,90 @@ +// Copyright 2024 Tetrate +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package oidc + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/tetratelabs/run" + "github.com/tetratelabs/telemetry" + + configv1 "github.com/tetrateio/authservice-go/config/gen/go/v1" + mockv1 "github.com/tetrateio/authservice-go/config/gen/go/v1/mock" + oidcv1 "github.com/tetrateio/authservice-go/config/gen/go/v1/oidc" +) + +func TestSessionStoreFactory(t *testing.T) { + config := &configv1.Config{ + ListenAddress: "0.0.0.0", + ListenPort: 8080, + LogLevel: "debug", + Threads: 1, + Chains: []*configv1.FilterChain{ + { + Name: "memory1", + Filters: []*configv1.Filter{ + {Type: &configv1.Filter_Mock{Mock: &mockv1.MockConfig{}}}, + {Type: &configv1.Filter_Oidc{Oidc: &oidcv1.OIDCConfig{}}}, + }, + }, + { + Name: "memory2", + Filters: []*configv1.Filter{ + {Type: &configv1.Filter_Oidc{Oidc: &oidcv1.OIDCConfig{}}}, + }, + }, + { + Name: "redis1", + Filters: []*configv1.Filter{ + { + Type: &configv1.Filter_Oidc{ + Oidc: &oidcv1.OIDCConfig{ + RedisSessionStoreConfig: &oidcv1.RedisConfig{ServerUri: "http://redis1:6379"}, + }, + }, + }, + }, + }, + { + Name: "redis2", + Filters: []*configv1.Filter{ + { + Type: &configv1.Filter_Oidc{ + Oidc: &oidcv1.OIDCConfig{ + RedisSessionStoreConfig: &oidcv1.RedisConfig{ServerUri: "http://redis2:6379"}, + }, + }, + }, + }, + }, + }, + } + + store := SessionStoreFactory{Config: config} + g := run.Group{Logger: telemetry.NoopLogger()} + g.Register(&store) + require.NoError(t, g.Run()) + + require.NotNil(t, store.memory) + require.Len(t, store.redis, 2) + + require.Nil(t, store.Get(nil)) + require.IsType(t, &memoryStore{}, store.Get(&oidcv1.OIDCConfig{})) + require.IsType(t, &memoryStore{}, store.Get(config.Chains[0].Filters[1].GetOidc())) + require.IsType(t, &memoryStore{}, store.Get(config.Chains[1].Filters[0].GetOidc())) + require.Equal(t, "http://redis1:6379", store.Get(config.Chains[2].Filters[0].GetOidc()).(*redisStore).url) + require.Equal(t, "http://redis2:6379", store.Get(config.Chains[3].Filters[0].GetOidc()).(*redisStore).url) +} diff --git a/internal/server/authz.go b/internal/server/authz.go index cfd9741..2d058a6 100644 --- a/internal/server/authz.go +++ b/internal/server/authz.go @@ -57,17 +57,19 @@ var ( // ExtAuthZFilter is an implementation of the Envoy AuthZ filter. type ExtAuthZFilter struct { - log telemetry.Logger - cfg *configv1.Config - jwks oidc.JWKSProvider + log telemetry.Logger + cfg *configv1.Config + jwks oidc.JWKSProvider + sessions *oidc.SessionStoreFactory } // NewExtAuthZFilter creates a new ExtAuthZFilter. -func NewExtAuthZFilter(cfg *configv1.Config, jwks oidc.JWKSProvider) *ExtAuthZFilter { +func NewExtAuthZFilter(cfg *configv1.Config, jwks oidc.JWKSProvider, sessions *oidc.SessionStoreFactory) *ExtAuthZFilter { return &ExtAuthZFilter{ - log: internal.Logger(internal.Authz), - cfg: cfg, - jwks: jwks, + log: internal.Logger(internal.Authz), + cfg: cfg, + jwks: jwks, + sessions: sessions, } } @@ -117,7 +119,7 @@ func (e *ExtAuthZFilter) Check(ctx context.Context, req *envoy.CheckRequest) (re h = authz.NewMockHandler(ft.Mock) case *configv1.Filter_Oidc: // TODO(nacx): Check if the Oidc setting is enough or we have to pull the default Oidc settings - if h, err = authz.NewOIDCHandler(ft.Oidc, e.jwks); err != nil { + if h, err = authz.NewOIDCHandler(ft.Oidc, e.jwks, e.sessions); err != nil { return nil, err } } diff --git a/internal/server/authz_test.go b/internal/server/authz_test.go index c6ef672..49684fe 100644 --- a/internal/server/authz_test.go +++ b/internal/server/authz_test.go @@ -39,7 +39,7 @@ func TestUnmatchedRequests(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - e := NewExtAuthZFilter(&configv1.Config{AllowUnmatchedRequests: tt.allow}, nil) + e := NewExtAuthZFilter(&configv1.Config{AllowUnmatchedRequests: tt.allow}, nil, nil) got, err := e.Check(context.Background(), &envoy.CheckRequest{}) require.NoError(t, err) require.Equal(t, int32(tt.want), got.Status.Code) @@ -61,7 +61,7 @@ func TestFiltersMatch(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { cfg := &configv1.Config{Chains: []*configv1.FilterChain{{Filters: tt.filters}}} - e := NewExtAuthZFilter(cfg, nil) + e := NewExtAuthZFilter(cfg, nil, nil) got, err := e.Check(context.Background(), &envoy.CheckRequest{}) require.NoError(t, err) @@ -91,7 +91,7 @@ func TestUseFirstMatchingChain(t *testing.T) { }, } - e := NewExtAuthZFilter(cfg, nil) + e := NewExtAuthZFilter(cfg, nil, nil) got, err := e.Check(context.Background(), header("match")) require.NoError(t, err) @@ -121,7 +121,7 @@ func TestMatch(t *testing.T) { } func TestGrpcNoChainsMatched(t *testing.T) { - e := NewExtAuthZFilter(&configv1.Config{}, nil) + e := NewExtAuthZFilter(&configv1.Config{}, nil, nil) s := NewTestServer(e.Register) go func() { require.NoError(t, s.Start()) }() t.Cleanup(s.Stop) @@ -274,7 +274,7 @@ func TestCheckTriggerRules(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - e := NewExtAuthZFilter(tt.config, nil) + e := NewExtAuthZFilter(tt.config, nil, nil) req := &envoy.CheckRequest{ Attributes: &envoy.AttributeContext{ Request: &envoy.AttributeContext_Request{