Skip to content
This repository has been archived by the owner on Apr 22, 2024. It is now read-only.

Commit

Permalink
Load the session stores for all OIDC configurations (#10)
Browse files Browse the repository at this point in the history
  • Loading branch information
nacx authored Feb 14, 2024
1 parent 58dfe7b commit 9d0e3d3
Show file tree
Hide file tree
Showing 9 changed files with 216 additions and 36 deletions.
7 changes: 6 additions & 1 deletion .github/codecov.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ comment:
layout: "diff,files"

ignore:
- "api/**/*"
- "config/gen/**/*"

coverage:
status:
Expand All @@ -20,3 +20,8 @@ coverage:
target: auto
# allow a potential drop of up to 5%
threshold: 5%
patch:
default:
target: auto
only_pulls: true
threshold: 0%
4 changes: 3 additions & 1 deletion cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ func main() {
configFile = &internal.LocalConfigFile{}
logging = internal.NewLogSystem(log.New(), &configFile.Config)
jwks = oidc.NewJWKSProvider()
envoyAuthz = server.NewExtAuthZFilter(&configFile.Config, jwks)
sessions = &oidc.SessionStoreFactory{Config: &configFile.Config}
envoyAuthz = server.NewExtAuthZFilter(&configFile.Config, jwks, sessions)
authzServer = server.New(&configFile.Config, envoyAuthz.Register)
)

Expand All @@ -52,6 +53,7 @@ func main() {
logging, // set up the logging system
configLog, // log the configuration
jwks, // start the JWKS provider
sessions, // start the session store
authzServer, // start the server
&signal.Handler{}, // handle graceful termination
)
Expand Down
27 changes: 9 additions & 18 deletions internal/authz/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ package authz

import (
"context"
"time"

envoy "github.com/envoyproxy/go-control-plane/envoy/service/auth/v3"
"github.com/tetratelabs/telemetry"
Expand All @@ -31,27 +30,19 @@ var _ Handler = (*oidcHandler)(nil)
// oidc handler is an implementation of the Handler interface that implements
// the OpenID connect protocol.
type oidcHandler struct {
log telemetry.Logger
config *oidcv1.OIDCConfig
store oidc.SessionStore
jwks oidc.JWKSProvider
log telemetry.Logger
config *oidcv1.OIDCConfig
jwks oidc.JWKSProvider
sessions *oidc.SessionStoreFactory
}

// NewOIDCHandler creates a new OIDC implementation of the Handler interface.
func NewOIDCHandler(cfg *oidcv1.OIDCConfig, jwks oidc.JWKSProvider) (Handler, error) {
// TODO(nacx): Read the redis store config to configure the redi store
// TODO(nacx): Properly lifecycle the session store
store := oidc.NewMemoryStore(
oidc.Clock{},
time.Duration(cfg.AbsoluteSessionTimeout),
time.Duration(cfg.IdleSessionTimeout),
)

func NewOIDCHandler(cfg *oidcv1.OIDCConfig, jwks oidc.JWKSProvider, sessions *oidc.SessionStoreFactory) (Handler, error) {
return &oidcHandler{
log: internal.Logger(internal.Authz).With("type", "oidc"),
config: cfg,
store: store,
jwks: jwks,
log: internal.Logger(internal.Authz).With("type", "oidc"),
config: cfg,
jwks: jwks,
sessions: sessions,
}, nil
}

Expand Down
10 changes: 7 additions & 3 deletions internal/oidc/jwks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ import (
"github.com/lestrrat-go/jwx/jwa"
"github.com/lestrrat-go/jwx/jwk"
"github.com/stretchr/testify/require"
"github.com/tetratelabs/run"
"github.com/tetratelabs/telemetry"

oidcv1 "github.com/tetrateio/authservice-go/config/gen/go/v1/oidc"
)
Expand Down Expand Up @@ -143,8 +145,11 @@ func TestDynamicJWKSProvider(t *testing.T) {

newCache = func(t *testing.T) JWKSProvider {
cache := NewJWKSProvider()
go func() { require.NoError(t, cache.Serve()) }()
g := run.Group{Logger: telemetry.NoopLogger()}
g.Register(cache)
go func() { _ = g.Run() }()
t.Cleanup(cache.GracefulStop)

// Block until the cache is initialized
require.Eventually(t, func() bool {
return cache.cache != nil
Expand All @@ -160,8 +165,7 @@ func TestDynamicJWKSProvider(t *testing.T) {
config := &oidcv1.OIDCConfig{
JwksConfig: &oidcv1.OIDCConfig_JwksFetcher{
JwksFetcher: &oidcv1.OIDCConfig_JwksFetcherConfig{
JwksUri: server.URL + "/not-found",
PeriodicFetchIntervalSec: 1,
JwksUri: server.URL + "/not-found",
},
},
}
Expand Down
25 changes: 25 additions & 0 deletions internal/oidc/redis.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Copyright 2024 Tetrate
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package oidc

var _ SessionStore = (*redisStore)(nil)

// redisStore is an in-memory implementation of the SessionStore interface that stores
// the session data in a given Redis server.
type redisStore struct {
// TODO(nacx): Remove the interface embedding and implement it
SessionStore
url string
}
61 changes: 61 additions & 0 deletions internal/oidc/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,15 @@

package oidc

import (
"time"

"github.com/tetratelabs/run"

configv1 "github.com/tetrateio/authservice-go/config/gen/go/v1"
oidcv1 "github.com/tetrateio/authservice-go/config/gen/go/v1/oidc"
)

// SessionStore is an interface for storing session data.
type SessionStore interface {
SetTokenResponse(sessionID string, tokenResponse *TokenResponse)
Expand All @@ -24,3 +33,55 @@ type SessionStore interface {
RemoveSession(sessionID string)
RemoveAllExpired()
}

var _ run.PreRunner = (*SessionStoreFactory)(nil)

// SessionStoreFactory is a factory for creating session stores.
// It uses the OIDC configuration to determine which store to use.
type SessionStoreFactory struct {
Config *configv1.Config

redis map[string]SessionStore
memory SessionStore
}

// Name implements run.Unit.
func (s *SessionStoreFactory) Name() string { return "OIDC session store factory" }

// PreRun initializes the stores that are defined in the configuration
func (s *SessionStoreFactory) PreRun() error {
s.redis = make(map[string]SessionStore)

for _, fc := range s.Config.Chains {
for _, f := range fc.Filters {
if f.GetOidc() == nil {
continue
}

if redisServer := f.GetOidc().GetRedisSessionStoreConfig().GetServerUri(); redisServer != "" {
// TODO(nacx): Initialize the Redis store
s.redis[redisServer] = &redisStore{url: redisServer}
} else if s.memory == nil { // Use a shared in-memory store for all OIDC configurations
s.memory = NewMemoryStore(
Clock{},
time.Duration(f.GetOidc().GetAbsoluteSessionTimeout()),
time.Duration(f.GetOidc().GetIdleSessionTimeout()),
)
}
}
}

return nil
}

// Get returns the appropriate session store for the given OIDC configuration.
func (s *SessionStoreFactory) Get(cfg *oidcv1.OIDCConfig) SessionStore {
if cfg == nil {
return nil
}
store, ok := s.redis[cfg.GetRedisSessionStoreConfig().GetServerUri()]
if !ok {
store = s.memory
}
return store
}
90 changes: 90 additions & 0 deletions internal/oidc/session_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
// Copyright 2024 Tetrate
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package oidc

import (
"testing"

"github.com/stretchr/testify/require"
"github.com/tetratelabs/run"
"github.com/tetratelabs/telemetry"

configv1 "github.com/tetrateio/authservice-go/config/gen/go/v1"
mockv1 "github.com/tetrateio/authservice-go/config/gen/go/v1/mock"
oidcv1 "github.com/tetrateio/authservice-go/config/gen/go/v1/oidc"
)

func TestSessionStoreFactory(t *testing.T) {
config := &configv1.Config{
ListenAddress: "0.0.0.0",
ListenPort: 8080,
LogLevel: "debug",
Threads: 1,
Chains: []*configv1.FilterChain{
{
Name: "memory1",
Filters: []*configv1.Filter{
{Type: &configv1.Filter_Mock{Mock: &mockv1.MockConfig{}}},
{Type: &configv1.Filter_Oidc{Oidc: &oidcv1.OIDCConfig{}}},
},
},
{
Name: "memory2",
Filters: []*configv1.Filter{
{Type: &configv1.Filter_Oidc{Oidc: &oidcv1.OIDCConfig{}}},
},
},
{
Name: "redis1",
Filters: []*configv1.Filter{
{
Type: &configv1.Filter_Oidc{
Oidc: &oidcv1.OIDCConfig{
RedisSessionStoreConfig: &oidcv1.RedisConfig{ServerUri: "http://redis1:6379"},
},
},
},
},
},
{
Name: "redis2",
Filters: []*configv1.Filter{
{
Type: &configv1.Filter_Oidc{
Oidc: &oidcv1.OIDCConfig{
RedisSessionStoreConfig: &oidcv1.RedisConfig{ServerUri: "http://redis2:6379"},
},
},
},
},
},
},
}

store := SessionStoreFactory{Config: config}
g := run.Group{Logger: telemetry.NoopLogger()}
g.Register(&store)
require.NoError(t, g.Run())

require.NotNil(t, store.memory)
require.Len(t, store.redis, 2)

require.Nil(t, store.Get(nil))
require.IsType(t, &memoryStore{}, store.Get(&oidcv1.OIDCConfig{}))
require.IsType(t, &memoryStore{}, store.Get(config.Chains[0].Filters[1].GetOidc()))
require.IsType(t, &memoryStore{}, store.Get(config.Chains[1].Filters[0].GetOidc()))
require.Equal(t, "http://redis1:6379", store.Get(config.Chains[2].Filters[0].GetOidc()).(*redisStore).url)
require.Equal(t, "http://redis2:6379", store.Get(config.Chains[3].Filters[0].GetOidc()).(*redisStore).url)
}
18 changes: 10 additions & 8 deletions internal/server/authz.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,19 @@ var (

// ExtAuthZFilter is an implementation of the Envoy AuthZ filter.
type ExtAuthZFilter struct {
log telemetry.Logger
cfg *configv1.Config
jwks oidc.JWKSProvider
log telemetry.Logger
cfg *configv1.Config
jwks oidc.JWKSProvider
sessions *oidc.SessionStoreFactory
}

// NewExtAuthZFilter creates a new ExtAuthZFilter.
func NewExtAuthZFilter(cfg *configv1.Config, jwks oidc.JWKSProvider) *ExtAuthZFilter {
func NewExtAuthZFilter(cfg *configv1.Config, jwks oidc.JWKSProvider, sessions *oidc.SessionStoreFactory) *ExtAuthZFilter {
return &ExtAuthZFilter{
log: internal.Logger(internal.Authz),
cfg: cfg,
jwks: jwks,
log: internal.Logger(internal.Authz),
cfg: cfg,
jwks: jwks,
sessions: sessions,
}
}

Expand Down Expand Up @@ -117,7 +119,7 @@ func (e *ExtAuthZFilter) Check(ctx context.Context, req *envoy.CheckRequest) (re
h = authz.NewMockHandler(ft.Mock)
case *configv1.Filter_Oidc:
// TODO(nacx): Check if the Oidc setting is enough or we have to pull the default Oidc settings
if h, err = authz.NewOIDCHandler(ft.Oidc, e.jwks); err != nil {
if h, err = authz.NewOIDCHandler(ft.Oidc, e.jwks, e.sessions); err != nil {
return nil, err
}
}
Expand Down
10 changes: 5 additions & 5 deletions internal/server/authz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func TestUnmatchedRequests(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := NewExtAuthZFilter(&configv1.Config{AllowUnmatchedRequests: tt.allow}, nil)
e := NewExtAuthZFilter(&configv1.Config{AllowUnmatchedRequests: tt.allow}, nil, nil)
got, err := e.Check(context.Background(), &envoy.CheckRequest{})
require.NoError(t, err)
require.Equal(t, int32(tt.want), got.Status.Code)
Expand All @@ -61,7 +61,7 @@ func TestFiltersMatch(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := &configv1.Config{Chains: []*configv1.FilterChain{{Filters: tt.filters}}}
e := NewExtAuthZFilter(cfg, nil)
e := NewExtAuthZFilter(cfg, nil, nil)

got, err := e.Check(context.Background(), &envoy.CheckRequest{})
require.NoError(t, err)
Expand Down Expand Up @@ -91,7 +91,7 @@ func TestUseFirstMatchingChain(t *testing.T) {
},
}

e := NewExtAuthZFilter(cfg, nil)
e := NewExtAuthZFilter(cfg, nil, nil)

got, err := e.Check(context.Background(), header("match"))
require.NoError(t, err)
Expand Down Expand Up @@ -121,7 +121,7 @@ func TestMatch(t *testing.T) {
}

func TestGrpcNoChainsMatched(t *testing.T) {
e := NewExtAuthZFilter(&configv1.Config{}, nil)
e := NewExtAuthZFilter(&configv1.Config{}, nil, nil)
s := NewTestServer(e.Register)
go func() { require.NoError(t, s.Start()) }()
t.Cleanup(s.Stop)
Expand Down Expand Up @@ -274,7 +274,7 @@ func TestCheckTriggerRules(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := NewExtAuthZFilter(tt.config, nil)
e := NewExtAuthZFilter(tt.config, nil, nil)
req := &envoy.CheckRequest{
Attributes: &envoy.AttributeContext{
Request: &envoy.AttributeContext_Request{
Expand Down

0 comments on commit 9d0e3d3

Please sign in to comment.