Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Add Interceptor for RateLimit
Browse files Browse the repository at this point in the history
  • Loading branch information
LaPetiteSouris committed May 7, 2023
1 parent 2195181 commit b4f04b6
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 39 deletions.
6 changes: 6 additions & 0 deletions pkg/server/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@ func newGRPCServer(ctx context.Context, pluginRegistry *plugins.Registry, cfg *c
auth.AuthenticationLoggingInterceptor,
middlewareInterceptors,
)
if cfg.Security.RateLimit.Enabled {
rateLimiter := plugins.NewRateLimiter(cfg.Security.RateLimit.RequestsPerSecond, cfg.Security.RateLimit.BurstSize, cfg.Security.RateLimit.CleanupInterval.Duration)
rateLimitInterceptors := plugins.RateLimiteInterceptor(*rateLimiter)
chainedUnaryInterceptors = grpcmiddleware.ChainUnaryServer(chainedUnaryInterceptors, rateLimitInterceptors)
}
} else {
logger.Infof(ctx, "Creating gRPC server without authentication")
chainedUnaryInterceptors = grpcmiddleware.ChainUnaryServer(grpcprometheus.UnaryServerInterceptor)
Expand Down Expand Up @@ -257,6 +262,7 @@ func serveGatewayInsecure(ctx context.Context, pluginRegistry *plugins.Registry,
}

oauth2ResourceServer = oauth2Provider

} else {
oauth2ResourceServer, err = authzserver.NewOAuth2ResourceServer(ctx, authCfg.AppAuth.ExternalAuthServer, authCfg.UserAuth.OpenID.BaseURL)
if err != nil {
Expand Down
52 changes: 45 additions & 7 deletions plugins/rate_limit.go
Original file line number Diff line number Diff line change
@@ -1,22 +1,28 @@
package plugins

import (
"context"
"errors"
"fmt"
"sync"
"time"

auth "github.com/flyteorg/flyteadmin/auth"
"golang.org/x/time/rate"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

type RateLimitError error
type RateLimitExceeded error

// define a struct that contains a map of rate limiters, and a time stamp of last access and a mutex to protect the map
type accessRecords struct {
limiter *rate.Limiter
lastAccess time.Time
}

type Limiter struct {
type LimiterStore struct {
accessPerUser map[string]*accessRecords
mutex *sync.Mutex
requestPerSec int
Expand All @@ -27,7 +33,7 @@ type Limiter struct {
// define a function named Allow that takes userID and returns RateLimitError
// the function check if the user is in the map, if not, create a new accessRecords for the user
// then it check if the user can access the resource, if not, return RateLimitError
func (l *Limiter) Allow(userID string) error {
func (l *LimiterStore) Allow(userID string) error {
l.mutex.Lock()
defer l.mutex.Unlock()
if _, ok := l.accessPerUser[userID]; !ok {
Expand All @@ -38,13 +44,13 @@ func (l *Limiter) Allow(userID string) error {
}

if !l.accessPerUser[userID].limiter.Allow() {
return RateLimitError(fmt.Errorf("rate limit exceeded"))
return RateLimitExceeded(fmt.Errorf("rate limit exceeded"))
}

return nil
}

func (l *Limiter) clean() {
func (l *LimiterStore) clean() {
l.mutex.Lock()
defer l.mutex.Unlock()
for userID, accessRecord := range l.accessPerUser {
Expand All @@ -54,8 +60,8 @@ func (l *Limiter) clean() {
}
}

func NewRateLimiter(requestPerSec int, burstSize int, cleanupInterval time.Duration) *Limiter {
l := &Limiter{
func newRateLimitStore(requestPerSec int, burstSize int, cleanupInterval time.Duration) *LimiterStore {
l := &LimiterStore{
accessPerUser: make(map[string]*accessRecords),
mutex: &sync.Mutex{},
requestPerSec: requestPerSec,
Expand All @@ -72,3 +78,35 @@ func NewRateLimiter(requestPerSec int, burstSize int, cleanupInterval time.Durat

return l
}

type RateLimiter struct {
limiter *LimiterStore
}

func (r *RateLimiter) Limit(ctx context.Context) error {
IdenCtx := auth.IdentityContextFromContext(ctx)
if IdenCtx.IsEmpty() {
return errors.New("no identity context found")
}
userID := IdenCtx.UserID()
if err := r.limiter.Allow(userID); err != nil {
return err
}
return nil
}

func NewRateLimiter(requestPerSec int, burstSize int, cleanupInterval time.Duration) *RateLimiter {
limiter := newRateLimitStore(requestPerSec, burstSize, cleanupInterval)
return &RateLimiter{limiter: limiter}
}

func RateLimiteInterceptor(limiter RateLimiter) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (
resp interface{}, err error) {
if err := limiter.Limit(ctx); err != nil {
return nil, status.Errorf(codes.ResourceExhausted, "rate limit exceeded")
}

return handler(ctx, req)
}
}
103 changes: 71 additions & 32 deletions plugins/rate_limit_test.go
Original file line number Diff line number Diff line change
@@ -1,56 +1,95 @@
package plugins

import (
"context"
"testing"
"time"

auth "github.com/flyteorg/flyteadmin/auth"
"github.com/stretchr/testify/assert"
)

func TestNewRateLimiter(t *testing.T) {
rl := NewRateLimiter(1, 1, time.Second)
assert.NotNil(t, rl)
rlStore := newRateLimitStore(1, 1, time.Second)
assert.NotNil(t, rlStore)
}

func TestLimiter_Allow(t *testing.T) {
rl := NewRateLimiter(1, 1, time.Second)
assert.NoError(t, rl.Allow("hello"))
// assert error type is RateLimitError
assert.Error(t, rl.Allow("hello"))
func TestLimiterAllow(t *testing.T) {
rlStore := newRateLimitStore(1, 1, time.Second)
assert.NoError(t, rlStore.Allow("hello"))
assert.Error(t, rlStore.Allow("hello"))
time.Sleep(time.Second)
assert.NoError(t, rl.Allow("hello"))
assert.NoError(t, rlStore.Allow("hello"))
}

func TestLimiter_AllowBurst(t *testing.T) {
rl := NewRateLimiter(1, 2, time.Second)
assert.NoError(t, rl.Allow("hello"))
assert.NoError(t, rl.Allow("hello"))
assert.Error(t, rl.Allow("hello"))
assert.NoError(t, rl.Allow("world"))
func TestLimiterAllowBurst(t *testing.T) {
rlStore := newRateLimitStore(1, 2, time.Second)
assert.NoError(t, rlStore.Allow("hello"))
assert.NoError(t, rlStore.Allow("hello"))
assert.Error(t, rlStore.Allow("hello"))
assert.NoError(t, rlStore.Allow("world"))
}

func TestLimiter_Clean(t *testing.T) {
rl := NewRateLimiter(1, 1, time.Second)
assert.NoError(t, rl.Allow("hello"))
assert.Error(t, rl.Allow("hello"))
func TestLimiterClean(t *testing.T) {
rlStore := newRateLimitStore(1, 1, time.Second)
assert.NoError(t, rlStore.Allow("hello"))
assert.Error(t, rlStore.Allow("hello"))
time.Sleep(time.Second)
rl.clean()
assert.NoError(t, rl.Allow("hello"))
rlStore.clean()
assert.NoError(t, rlStore.Allow("hello"))
}

func TestLimiter_AllowOnMultipleRequests(t *testing.T) {
rl := NewRateLimiter(1, 1, time.Second)
assert.NoError(t, rl.Allow("a"))
assert.NoError(t, rl.Allow("b"))
assert.NoError(t, rl.Allow("c"))
assert.Error(t, rl.Allow("a"))
assert.Error(t, rl.Allow("b"))
func TestLimiterAllowOnMultipleRequests(t *testing.T) {
rlStore := newRateLimitStore(1, 1, time.Second)
assert.NoError(t, rlStore.Allow("a"))
assert.NoError(t, rlStore.Allow("b"))
assert.NoError(t, rlStore.Allow("c"))
assert.Error(t, rlStore.Allow("a"))
assert.Error(t, rlStore.Allow("b"))

time.Sleep(time.Second)

assert.NoError(t, rl.Allow("a"))
assert.Error(t, rl.Allow("a"))
assert.NoError(t, rl.Allow("b"))
assert.Error(t, rl.Allow("b"))
assert.NoError(t, rl.Allow("c"))
assert.NoError(t, rlStore.Allow("a"))
assert.Error(t, rlStore.Allow("a"))
assert.NoError(t, rlStore.Allow("b"))
assert.Error(t, rlStore.Allow("b"))
assert.NoError(t, rlStore.Allow("c"))
}

func TestRateLimiterLimitPass(t *testing.T) {
rateLimit := NewRateLimiter(1, 1, time.Second)
assert.NotNil(t, rateLimit)

identityCtx, err := auth.NewIdentityContext("audience", "user1", "app1", time.Now(), nil, nil, nil)
assert.NoError(t, err)

ctx := context.WithValue(context.TODO(), auth.ContextKeyIdentityContext, identityCtx)
err = rateLimit.Limit(ctx)
assert.NoError(t, err)

}

func TestRateLimiterLimitStop(t *testing.T) {
rateLimit := NewRateLimiter(1, 1, time.Second)
assert.NotNil(t, rateLimit)

identityCtx, err := auth.NewIdentityContext("audience", "user1", "app1", time.Now(), nil, nil, nil)
assert.NoError(t, err)
ctx := context.WithValue(context.TODO(), auth.ContextKeyIdentityContext, identityCtx)
err = rateLimit.Limit(ctx)
assert.NoError(t, err)

err = rateLimit.Limit(ctx)
assert.Error(t, err)

}

func TestRateLimiterLimitWithoutUserIdentity(t *testing.T) {
rateLimit := NewRateLimiter(1, 1, time.Second)
assert.NotNil(t, rateLimit)

ctx := context.TODO()

err := rateLimit.Limit(ctx)
assert.Error(t, err)
}

0 comments on commit b4f04b6

Please sign in to comment.