Skip to content
This repository has been archived by the owner on Apr 22, 2024. It is now read-only.

Load the session stores for all OIDC configurations #10

Merged
merged 2 commits into from
Feb 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion .github/codecov.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ comment:
layout: "diff,files"

ignore:
- "api/**/*"
- "config/gen/**/*"

coverage:
status:
Expand All @@ -20,3 +20,8 @@ coverage:
target: auto
# allow a potential drop of up to 5%
threshold: 5%
patch:
default:
target: auto
only_pulls: true
threshold: 0%
4 changes: 3 additions & 1 deletion cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ func main() {
configFile = &internal.LocalConfigFile{}
logging = internal.NewLogSystem(log.New(), &configFile.Config)
jwks = oidc.NewJWKSProvider()
envoyAuthz = server.NewExtAuthZFilter(&configFile.Config, jwks)
sessions = &oidc.SessionStoreFactory{Config: &configFile.Config}
envoyAuthz = server.NewExtAuthZFilter(&configFile.Config, jwks, sessions)
authzServer = server.New(&configFile.Config, envoyAuthz.Register)
)

Expand All @@ -52,6 +53,7 @@ func main() {
logging, // set up the logging system
configLog, // log the configuration
jwks, // start the JWKS provider
sessions, // start the session store
authzServer, // start the server
&signal.Handler{}, // handle graceful termination
)
Expand Down
27 changes: 9 additions & 18 deletions internal/authz/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ package authz

import (
"context"
"time"

envoy "github.com/envoyproxy/go-control-plane/envoy/service/auth/v3"
"github.com/tetratelabs/telemetry"
Expand All @@ -31,27 +30,19 @@ var _ Handler = (*oidcHandler)(nil)
// oidc handler is an implementation of the Handler interface that implements
// the OpenID connect protocol.
type oidcHandler struct {
log telemetry.Logger
config *oidcv1.OIDCConfig
store oidc.SessionStore
jwks oidc.JWKSProvider
log telemetry.Logger
config *oidcv1.OIDCConfig
jwks oidc.JWKSProvider
sessions *oidc.SessionStoreFactory
}

// NewOIDCHandler creates a new OIDC implementation of the Handler interface.
func NewOIDCHandler(cfg *oidcv1.OIDCConfig, jwks oidc.JWKSProvider) (Handler, error) {
// TODO(nacx): Read the redis store config to configure the redi store
// TODO(nacx): Properly lifecycle the session store
store := oidc.NewMemoryStore(
oidc.Clock{},
time.Duration(cfg.AbsoluteSessionTimeout),
time.Duration(cfg.IdleSessionTimeout),
)

func NewOIDCHandler(cfg *oidcv1.OIDCConfig, jwks oidc.JWKSProvider, sessions *oidc.SessionStoreFactory) (Handler, error) {
return &oidcHandler{
log: internal.Logger(internal.Authz).With("type", "oidc"),
config: cfg,
store: store,
jwks: jwks,
log: internal.Logger(internal.Authz).With("type", "oidc"),
config: cfg,
jwks: jwks,
sessions: sessions,
}, nil
}

Expand Down
10 changes: 7 additions & 3 deletions internal/oidc/jwks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ import (
"github.com/lestrrat-go/jwx/jwa"
"github.com/lestrrat-go/jwx/jwk"
"github.com/stretchr/testify/require"
"github.com/tetratelabs/run"
"github.com/tetratelabs/telemetry"

oidcv1 "github.com/tetrateio/authservice-go/config/gen/go/v1/oidc"
)
Expand Down Expand Up @@ -143,8 +145,11 @@ func TestDynamicJWKSProvider(t *testing.T) {

newCache = func(t *testing.T) JWKSProvider {
cache := NewJWKSProvider()
go func() { require.NoError(t, cache.Serve()) }()
g := run.Group{Logger: telemetry.NoopLogger()}
g.Register(cache)
go func() { _ = g.Run() }()
t.Cleanup(cache.GracefulStop)

// Block until the cache is initialized
require.Eventually(t, func() bool {
return cache.cache != nil
Expand All @@ -160,8 +165,7 @@ func TestDynamicJWKSProvider(t *testing.T) {
config := &oidcv1.OIDCConfig{
JwksConfig: &oidcv1.OIDCConfig_JwksFetcher{
JwksFetcher: &oidcv1.OIDCConfig_JwksFetcherConfig{
JwksUri: server.URL + "/not-found",
PeriodicFetchIntervalSec: 1,
JwksUri: server.URL + "/not-found",
},
},
}
Expand Down
25 changes: 25 additions & 0 deletions internal/oidc/redis.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Copyright 2024 Tetrate
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package oidc

var _ SessionStore = (*redisStore)(nil)

// redisStore is an in-memory implementation of the SessionStore interface that stores
// the session data in a given Redis server.
type redisStore struct {
// TODO(nacx): Remove the interface embedding and implement it
SessionStore
url string
}
61 changes: 61 additions & 0 deletions internal/oidc/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,15 @@

package oidc

import (
"time"

"github.com/tetratelabs/run"

configv1 "github.com/tetrateio/authservice-go/config/gen/go/v1"
oidcv1 "github.com/tetrateio/authservice-go/config/gen/go/v1/oidc"
)

// SessionStore is an interface for storing session data.
type SessionStore interface {
SetTokenResponse(sessionID string, tokenResponse *TokenResponse)
Expand All @@ -24,3 +33,55 @@ type SessionStore interface {
RemoveSession(sessionID string)
RemoveAllExpired()
}

var _ run.PreRunner = (*SessionStoreFactory)(nil)

// SessionStoreFactory is a factory for creating session stores.
// It uses the OIDC configuration to determine which store to use.
type SessionStoreFactory struct {
Config *configv1.Config

redis map[string]SessionStore
memory SessionStore
}

// Name implements run.Unit.
func (s *SessionStoreFactory) Name() string { return "OIDC session store factory" }

// PreRun initializes the stores that are defined in the configuration
func (s *SessionStoreFactory) PreRun() error {
s.redis = make(map[string]SessionStore)

for _, fc := range s.Config.Chains {
for _, f := range fc.Filters {
if f.GetOidc() == nil {
continue
}

if redisServer := f.GetOidc().GetRedisSessionStoreConfig().GetServerUri(); redisServer != "" {
// TODO(nacx): Initialize the Redis store
s.redis[redisServer] = &redisStore{url: redisServer}
} else if s.memory == nil { // Use a shared in-memory store for all OIDC configurations
s.memory = NewMemoryStore(
Clock{},
time.Duration(f.GetOidc().GetAbsoluteSessionTimeout()),
time.Duration(f.GetOidc().GetIdleSessionTimeout()),
)
}
}
}

return nil
}

// Get returns the appropriate session store for the given OIDC configuration.
func (s *SessionStoreFactory) Get(cfg *oidcv1.OIDCConfig) SessionStore {
if cfg == nil {
return nil
}
store, ok := s.redis[cfg.GetRedisSessionStoreConfig().GetServerUri()]
if !ok {
store = s.memory
}
return store
}
90 changes: 90 additions & 0 deletions internal/oidc/session_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
// Copyright 2024 Tetrate
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package oidc

import (
"testing"

"github.com/stretchr/testify/require"
"github.com/tetratelabs/run"
"github.com/tetratelabs/telemetry"

configv1 "github.com/tetrateio/authservice-go/config/gen/go/v1"
mockv1 "github.com/tetrateio/authservice-go/config/gen/go/v1/mock"
oidcv1 "github.com/tetrateio/authservice-go/config/gen/go/v1/oidc"
)

func TestSessionStoreFactory(t *testing.T) {
config := &configv1.Config{
ListenAddress: "0.0.0.0",
ListenPort: 8080,
LogLevel: "debug",
Threads: 1,
Chains: []*configv1.FilterChain{
{
Name: "memory1",
Filters: []*configv1.Filter{
{Type: &configv1.Filter_Mock{Mock: &mockv1.MockConfig{}}},
{Type: &configv1.Filter_Oidc{Oidc: &oidcv1.OIDCConfig{}}},
},
},
{
Name: "memory2",
Filters: []*configv1.Filter{
{Type: &configv1.Filter_Oidc{Oidc: &oidcv1.OIDCConfig{}}},
},
},
{
Name: "redis1",
Filters: []*configv1.Filter{
{
Type: &configv1.Filter_Oidc{
Oidc: &oidcv1.OIDCConfig{
RedisSessionStoreConfig: &oidcv1.RedisConfig{ServerUri: "http://redis1:6379"},
},
},
},
},
},
{
Name: "redis2",
Filters: []*configv1.Filter{
{
Type: &configv1.Filter_Oidc{
Oidc: &oidcv1.OIDCConfig{
RedisSessionStoreConfig: &oidcv1.RedisConfig{ServerUri: "http://redis2:6379"},
},
},
},
},
},
},
}

store := SessionStoreFactory{Config: config}
g := run.Group{Logger: telemetry.NoopLogger()}
g.Register(&store)
require.NoError(t, g.Run())

require.NotNil(t, store.memory)
require.Len(t, store.redis, 2)

require.Nil(t, store.Get(nil))
require.IsType(t, &memoryStore{}, store.Get(&oidcv1.OIDCConfig{}))
require.IsType(t, &memoryStore{}, store.Get(config.Chains[0].Filters[1].GetOidc()))
require.IsType(t, &memoryStore{}, store.Get(config.Chains[1].Filters[0].GetOidc()))
require.Equal(t, "http://redis1:6379", store.Get(config.Chains[2].Filters[0].GetOidc()).(*redisStore).url)
require.Equal(t, "http://redis2:6379", store.Get(config.Chains[3].Filters[0].GetOidc()).(*redisStore).url)
}
18 changes: 10 additions & 8 deletions internal/server/authz.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,19 @@ var (

// ExtAuthZFilter is an implementation of the Envoy AuthZ filter.
type ExtAuthZFilter struct {
log telemetry.Logger
cfg *configv1.Config
jwks oidc.JWKSProvider
log telemetry.Logger
cfg *configv1.Config
jwks oidc.JWKSProvider
sessions *oidc.SessionStoreFactory
}

// NewExtAuthZFilter creates a new ExtAuthZFilter.
func NewExtAuthZFilter(cfg *configv1.Config, jwks oidc.JWKSProvider) *ExtAuthZFilter {
func NewExtAuthZFilter(cfg *configv1.Config, jwks oidc.JWKSProvider, sessions *oidc.SessionStoreFactory) *ExtAuthZFilter {
return &ExtAuthZFilter{
log: internal.Logger(internal.Authz),
cfg: cfg,
jwks: jwks,
log: internal.Logger(internal.Authz),
cfg: cfg,
jwks: jwks,
sessions: sessions,
}
}

Expand Down Expand Up @@ -117,7 +119,7 @@ func (e *ExtAuthZFilter) Check(ctx context.Context, req *envoy.CheckRequest) (re
h = authz.NewMockHandler(ft.Mock)
case *configv1.Filter_Oidc:
// TODO(nacx): Check if the Oidc setting is enough or we have to pull the default Oidc settings
if h, err = authz.NewOIDCHandler(ft.Oidc, e.jwks); err != nil {
if h, err = authz.NewOIDCHandler(ft.Oidc, e.jwks, e.sessions); err != nil {
return nil, err
}
}
Expand Down
10 changes: 5 additions & 5 deletions internal/server/authz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func TestUnmatchedRequests(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := NewExtAuthZFilter(&configv1.Config{AllowUnmatchedRequests: tt.allow}, nil)
e := NewExtAuthZFilter(&configv1.Config{AllowUnmatchedRequests: tt.allow}, nil, nil)
got, err := e.Check(context.Background(), &envoy.CheckRequest{})
require.NoError(t, err)
require.Equal(t, int32(tt.want), got.Status.Code)
Expand All @@ -61,7 +61,7 @@ func TestFiltersMatch(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := &configv1.Config{Chains: []*configv1.FilterChain{{Filters: tt.filters}}}
e := NewExtAuthZFilter(cfg, nil)
e := NewExtAuthZFilter(cfg, nil, nil)

got, err := e.Check(context.Background(), &envoy.CheckRequest{})
require.NoError(t, err)
Expand Down Expand Up @@ -91,7 +91,7 @@ func TestUseFirstMatchingChain(t *testing.T) {
},
}

e := NewExtAuthZFilter(cfg, nil)
e := NewExtAuthZFilter(cfg, nil, nil)

got, err := e.Check(context.Background(), header("match"))
require.NoError(t, err)
Expand Down Expand Up @@ -121,7 +121,7 @@ func TestMatch(t *testing.T) {
}

func TestGrpcNoChainsMatched(t *testing.T) {
e := NewExtAuthZFilter(&configv1.Config{}, nil)
e := NewExtAuthZFilter(&configv1.Config{}, nil, nil)
s := NewTestServer(e.Register)
go func() { require.NoError(t, s.Start()) }()
t.Cleanup(s.Stop)
Expand Down Expand Up @@ -274,7 +274,7 @@ func TestCheckTriggerRules(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := NewExtAuthZFilter(tt.config, nil)
e := NewExtAuthZFilter(tt.config, nil, nil)
req := &envoy.CheckRequest{
Attributes: &envoy.AttributeContext{
Request: &envoy.AttributeContext_Request{
Expand Down
Loading