From 271bc9f859bdab531c29cf2edceaf738a4f99da9 Mon Sep 17 00:00:00 2001 From: TungHoang Date: Sun, 7 May 2023 16:36:26 +0200 Subject: [PATCH 1/8] Add rate_limiter in plugins Signed-off-by: TungHoang --- pkg/config/config.go | 10 +++++- plugins/rate_limit.go | 74 ++++++++++++++++++++++++++++++++++++++ plugins/rate_limit_test.go | 56 +++++++++++++++++++++++++++++ 3 files changed, 139 insertions(+), 1 deletion(-) create mode 100644 plugins/rate_limit.go create mode 100644 plugins/rate_limit_test.go diff --git a/pkg/config/config.go b/pkg/config/config.go index 28996bd23..245b781c5 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -80,7 +80,8 @@ type ServerSecurityOptions struct { // These are the Access-Control-Request-Headers that the server will respond to. // By default, the server will allow Accept, Accept-Language, Content-Language, and Content-Type. // DeprecatedUser this setting to add any additional headers which are needed - AllowedHeaders []string `json:"allowedHeaders"` + AllowedHeaders []string `json:"allowedHeaders"` + RateLimit RateLimitOptions `json:"rateLimitOptions"` } type SslOptions struct { @@ -88,6 +89,13 @@ type SslOptions struct { KeyFile string `json:"keyFile"` } +// declare RateLimitConfig +type RateLimitOptions struct { + RequestsPerSecond int + BurstSize int + CleanupInterval config.Duration +} + var defaultServerConfig = &ServerConfig{ HTTPPort: 8088, Security: ServerSecurityOptions{ diff --git a/plugins/rate_limit.go b/plugins/rate_limit.go new file mode 100644 index 000000000..283461cd8 --- /dev/null +++ b/plugins/rate_limit.go @@ -0,0 +1,74 @@ +package plugins + +import ( + "fmt" + "sync" + "time" + + "golang.org/x/time/rate" +) + +type RateLimitError error + +// define a struct that contains a map of rate limiters, and a time stamp of last access and a mutex to protect the map +type accessRecords struct { + limiter *rate.Limiter + lastAccess time.Time +} + +type Limiter struct { + accessPerUser map[string]*accessRecords + mutex *sync.Mutex + requestPerSec int + burstSize int + cleanupInterval time.Duration +} + +// define a function named Allow that takes userID and returns RateLimitError +// the function check if the user is in the map, if not, create a new accessRecords for the user +// then it check if the user can access the resource, if not, return RateLimitError +func (l *Limiter) Allow(userID string) RateLimitError { + l.mutex.Lock() + defer l.mutex.Unlock() + if _, ok := l.accessPerUser[userID]; !ok { + l.accessPerUser[userID] = &accessRecords{ + lastAccess: time.Now(), + limiter: rate.NewLimiter(rate.Limit(l.requestPerSec), l.burstSize), + } + } + + if !l.accessPerUser[userID].limiter.Allow() { + return RateLimitError(fmt.Errorf("rate limit exceeded")) + } + + return nil +} + +func (l *Limiter) clean() { + l.mutex.Lock() + defer l.mutex.Unlock() + for userID, accessRecord := range l.accessPerUser { + if time.Since(accessRecord.lastAccess) > l.cleanupInterval { + delete(l.accessPerUser, userID) + } + } +} + +func NewRateLimiter(requestPerSec int, burstSize int, cleanupInterval time.Duration) *Limiter { + l := &Limiter{ + accessPerUser: make(map[string]*accessRecords), + mutex: &sync.Mutex{}, + requestPerSec: requestPerSec, + burstSize: burstSize, + cleanupInterval: cleanupInterval, + } + + go func() { + for { + time.Sleep(l.cleanupInterval) + l.clean() + } + }() + + return l +} diff --git a/plugins/rate_limit_test.go b/plugins/rate_limit_test.go new file mode 100644 index 000000000..0b045ed44 --- /dev/null +++ b/plugins/rate_limit_test.go @@ -0,0 +1,56 @@ +package plugins + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestNewRateLimiter(t *testing.T) { + rl := NewRateLimiter(1, 1, time.Second) + assert.NotNil(t, rl) +} + +func TestLimiter_Allow(t *testing.T) { + rl := NewRateLimiter(1, 1, time.Second) + assert.NoError(t, rl.Allow("hello")) + // assert error type is RateLimitError + assert.Error(t, rl.Allow("hello")) + time.Sleep(time.Second) + assert.NoError(t, rl.Allow("hello")) +} + +func TestLimiter_AllowBurst(t *testing.T) { + rl := NewRateLimiter(1, 2, time.Second) + assert.NoError(t, rl.Allow("hello")) + assert.NoError(t, rl.Allow("hello")) + assert.Error(t, rl.Allow("hello")) + assert.NoError(t, rl.Allow("world")) +} + +func TestLimiter_Clean(t *testing.T) { + rl := NewRateLimiter(1, 1, time.Second) + assert.NoError(t, rl.Allow("hello")) + assert.Error(t, rl.Allow("hello")) + time.Sleep(time.Second) + rl.clean() + assert.NoError(t, rl.Allow("hello")) +} + +func TestLimiter_AllowOnMultipleRequests(t *testing.T) { + rl := NewRateLimiter(1, 1, time.Second) + assert.NoError(t, rl.Allow("a")) + assert.NoError(t, rl.Allow("b")) + assert.NoError(t, rl.Allow("c")) + assert.Error(t, rl.Allow("a")) + assert.Error(t, rl.Allow("b")) + + time.Sleep(time.Second) + + assert.NoError(t, rl.Allow("a")) + assert.Error(t, rl.Allow("a")) + assert.NoError(t, rl.Allow("b")) + assert.Error(t, rl.Allow("b")) + assert.NoError(t, rl.Allow("c")) +} From 9f8daf7262512db4b78ce2828ad4fdcbdf882a73 Mon Sep 17 00:00:00 2001 From: TungHoang Date: Sun, 7 May 2023 16:42:01 +0200 Subject: [PATCH 2/8] Add configs for rate limit Signed-off-by: TungHoang --- pkg/config/config.go | 7 ++++--- plugins/rate_limit.go | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/pkg/config/config.go b/pkg/config/config.go index 245b781c5..a855f75dd 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -91,9 +91,10 @@ type SslOptions struct { // declare RateLimitConfig type RateLimitOptions struct { - RequestsPerSecond int - BurstSize int - CleanupInterval config.Duration + Enabled bool `json:"enabled"` + RequestsPerSecond int `json:"requestsPerSecond"` + BurstSize int `json:"burstSize"` + CleanupInterval config.Duration `json:"cleanupInterval"` } var defaultServerConfig = &ServerConfig{ diff --git a/plugins/rate_limit.go b/plugins/rate_limit.go index 283461cd8..0197c0cf3 100644 --- a/plugins/rate_limit.go +++ b/plugins/rate_limit.go @@ -27,7 +27,7 @@ type Limiter struct { // define a function named Allow that takes userID and returns RateLimitError // the function check if the user is in the map, if not, create a new accessRecords for the user // then it check if the user can access the resource, if not, return RateLimitError -func (l *Limiter) Allow(userID string) RateLimitError { +func (l *Limiter) Allow(userID string) error { l.mutex.Lock() defer l.mutex.Unlock() if _, ok := l.accessPerUser[userID]; !ok { From 2b7fb81e3aa49ff5548e9bb5a3363e18c4e38a41 Mon Sep 17 00:00:00 2001 From: TungHoang Date: Sun, 7 May 2023 21:37:29 +0200 Subject: [PATCH 3/8] Add Interceptor for RateLimit Signed-off-by: TungHoang --- pkg/server/service.go | 6 +++ plugins/rate_limit.go | 52 ++++++++++++++++--- plugins/rate_limit_test.go | 103 +++++++++++++++++++++++++------------ 3 files changed, 122 insertions(+), 39 deletions(-) diff --git a/pkg/server/service.go b/pkg/server/service.go index f3b27416f..8d226e7b3 100644 --- a/pkg/server/service.go +++ b/pkg/server/service.go @@ -90,6 +90,11 @@ func newGRPCServer(ctx context.Context, pluginRegistry *plugins.Registry, cfg *c auth.AuthenticationLoggingInterceptor, middlewareInterceptors, ) + if cfg.Security.RateLimit.Enabled { + rateLimiter := plugins.NewRateLimiter(cfg.Security.RateLimit.RequestsPerSecond, cfg.Security.RateLimit.BurstSize, cfg.Security.RateLimit.CleanupInterval.Duration) + rateLimitInterceptors := plugins.RateLimiteInterceptor(*rateLimiter) + chainedUnaryInterceptors = grpcmiddleware.ChainUnaryServer(chainedUnaryInterceptors, rateLimitInterceptors) + } } else { logger.Infof(ctx, "Creating gRPC server without authentication") chainedUnaryInterceptors = grpcmiddleware.ChainUnaryServer(grpcprometheus.UnaryServerInterceptor) @@ -257,6 +262,7 @@ func serveGatewayInsecure(ctx context.Context, pluginRegistry *plugins.Registry, } oauth2ResourceServer = oauth2Provider + } else { oauth2ResourceServer, err = authzserver.NewOAuth2ResourceServer(ctx, authCfg.AppAuth.ExternalAuthServer, authCfg.UserAuth.OpenID.BaseURL) if err != nil { diff --git a/plugins/rate_limit.go b/plugins/rate_limit.go index 0197c0cf3..d32cdf985 100644 --- a/plugins/rate_limit.go +++ b/plugins/rate_limit.go @@ -1,14 +1,20 @@ package plugins import ( + "context" + "errors" "fmt" "sync" "time" + auth "github.com/flyteorg/flyteadmin/auth" "golang.org/x/time/rate" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) -type RateLimitError error +type RateLimitExceeded error // define a struct that contains a map of rate limiters, and a time stamp of last access and a mutex to protect the map type accessRecords struct { @@ -16,7 +22,7 @@ type accessRecords struct { lastAccess time.Time } -type Limiter struct { +type LimiterStore struct { accessPerUser map[string]*accessRecords mutex *sync.Mutex requestPerSec int @@ -27,7 +33,7 @@ type Limiter struct { // define a function named Allow that takes userID and returns RateLimitError // the function check if the user is in the map, if not, create a new accessRecords for the user // then it check if the user can access the resource, if not, return RateLimitError -func (l *Limiter) Allow(userID string) error { +func (l *LimiterStore) Allow(userID string) error { l.mutex.Lock() defer l.mutex.Unlock() if _, ok := l.accessPerUser[userID]; !ok { @@ -38,13 +44,13 @@ func (l *Limiter) Allow(userID string) error { } if !l.accessPerUser[userID].limiter.Allow() { - return RateLimitError(fmt.Errorf("rate limit exceeded")) + return RateLimitExceeded(fmt.Errorf("rate limit exceeded")) } return nil } -func (l *Limiter) clean() { +func (l *LimiterStore) clean() { l.mutex.Lock() defer l.mutex.Unlock() for userID, accessRecord := range l.accessPerUser { @@ -54,8 +60,8 @@ func (l *Limiter) clean() { } } -func NewRateLimiter(requestPerSec int, burstSize int, cleanupInterval time.Duration) *Limiter { - l := &Limiter{ +func newRateLimitStore(requestPerSec int, burstSize int, cleanupInterval time.Duration) *LimiterStore { + l := &LimiterStore{ accessPerUser: make(map[string]*accessRecords), mutex: &sync.Mutex{}, requestPerSec: requestPerSec, @@ -72,3 +78,35 @@ func NewRateLimiter(requestPerSec int, burstSize int, cleanupInterval time.Durat return l } + +type RateLimiter struct { + limiter *LimiterStore +} + +func (r *RateLimiter) Limit(ctx context.Context) error { + IdenCtx := auth.IdentityContextFromContext(ctx) + if IdenCtx.IsEmpty() { + return errors.New("no identity context found") + } + userID := IdenCtx.UserID() + if err := r.limiter.Allow(userID); err != nil { + return err + } + return nil +} + +func NewRateLimiter(requestPerSec int, burstSize int, cleanupInterval time.Duration) *RateLimiter { + limiter := newRateLimitStore(requestPerSec, burstSize, cleanupInterval) + return &RateLimiter{limiter: limiter} +} + +func RateLimiteInterceptor(limiter RateLimiter) grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) ( + resp interface{}, err error) { + if err := limiter.Limit(ctx); err != nil { + return nil, status.Errorf(codes.ResourceExhausted, "rate limit exceeded") + } + + return handler(ctx, req) + } +} diff --git a/plugins/rate_limit_test.go b/plugins/rate_limit_test.go index 0b045ed44..87907b2f1 100644 --- a/plugins/rate_limit_test.go +++ b/plugins/rate_limit_test.go @@ -1,56 +1,95 @@ package plugins import ( + "context" "testing" "time" + auth "github.com/flyteorg/flyteadmin/auth" "github.com/stretchr/testify/assert" ) func TestNewRateLimiter(t *testing.T) { - rl := NewRateLimiter(1, 1, time.Second) - assert.NotNil(t, rl) + rlStore := newRateLimitStore(1, 1, time.Second) + assert.NotNil(t, rlStore) } -func TestLimiter_Allow(t *testing.T) { - rl := NewRateLimiter(1, 1, time.Second) - assert.NoError(t, rl.Allow("hello")) - // assert error type is RateLimitError - assert.Error(t, rl.Allow("hello")) +func TestLimiterAllow(t *testing.T) { + rlStore := newRateLimitStore(1, 1, time.Second) + assert.NoError(t, rlStore.Allow("hello")) + assert.Error(t, rlStore.Allow("hello")) time.Sleep(time.Second) - assert.NoError(t, rl.Allow("hello")) + assert.NoError(t, rlStore.Allow("hello")) } -func TestLimiter_AllowBurst(t *testing.T) { - rl := NewRateLimiter(1, 2, time.Second) - assert.NoError(t, rl.Allow("hello")) - assert.NoError(t, rl.Allow("hello")) - assert.Error(t, rl.Allow("hello")) - assert.NoError(t, rl.Allow("world")) +func TestLimiterAllowBurst(t *testing.T) { + rlStore := newRateLimitStore(1, 2, time.Second) + assert.NoError(t, rlStore.Allow("hello")) + assert.NoError(t, rlStore.Allow("hello")) + assert.Error(t, rlStore.Allow("hello")) + assert.NoError(t, rlStore.Allow("world")) } -func TestLimiter_Clean(t *testing.T) { - rl := NewRateLimiter(1, 1, time.Second) - assert.NoError(t, rl.Allow("hello")) - assert.Error(t, rl.Allow("hello")) +func TestLimiterClean(t *testing.T) { + rlStore := newRateLimitStore(1, 1, time.Second) + assert.NoError(t, rlStore.Allow("hello")) + assert.Error(t, rlStore.Allow("hello")) time.Sleep(time.Second) - rl.clean() - assert.NoError(t, rl.Allow("hello")) + rlStore.clean() + assert.NoError(t, rlStore.Allow("hello")) } -func TestLimiter_AllowOnMultipleRequests(t *testing.T) { - rl := NewRateLimiter(1, 1, time.Second) - assert.NoError(t, rl.Allow("a")) - assert.NoError(t, rl.Allow("b")) - assert.NoError(t, rl.Allow("c")) - assert.Error(t, rl.Allow("a")) - assert.Error(t, rl.Allow("b")) +func TestLimiterAllowOnMultipleRequests(t *testing.T) { + rlStore := newRateLimitStore(1, 1, time.Second) + assert.NoError(t, rlStore.Allow("a")) + assert.NoError(t, rlStore.Allow("b")) + assert.NoError(t, rlStore.Allow("c")) + assert.Error(t, rlStore.Allow("a")) + assert.Error(t, rlStore.Allow("b")) time.Sleep(time.Second) - assert.NoError(t, rl.Allow("a")) - assert.Error(t, rl.Allow("a")) - assert.NoError(t, rl.Allow("b")) - assert.Error(t, rl.Allow("b")) - assert.NoError(t, rl.Allow("c")) + assert.NoError(t, rlStore.Allow("a")) + assert.Error(t, rlStore.Allow("a")) + assert.NoError(t, rlStore.Allow("b")) + assert.Error(t, rlStore.Allow("b")) + assert.NoError(t, rlStore.Allow("c")) +} + +func TestRateLimiterLimitPass(t *testing.T) { + rateLimit := NewRateLimiter(1, 1, time.Second) + assert.NotNil(t, rateLimit) + + identityCtx, err := auth.NewIdentityContext("audience", "user1", "app1", time.Now(), nil, nil, nil) + assert.NoError(t, err) + + ctx := context.WithValue(context.TODO(), auth.ContextKeyIdentityContext, identityCtx) + err = rateLimit.Limit(ctx) + assert.NoError(t, err) + +} + +func TestRateLimiterLimitStop(t *testing.T) { + rateLimit := NewRateLimiter(1, 1, time.Second) + assert.NotNil(t, rateLimit) + + identityCtx, err := auth.NewIdentityContext("audience", "user1", "app1", time.Now(), nil, nil, nil) + assert.NoError(t, err) + ctx := context.WithValue(context.TODO(), auth.ContextKeyIdentityContext, identityCtx) + err = rateLimit.Limit(ctx) + assert.NoError(t, err) + + err = rateLimit.Limit(ctx) + assert.Error(t, err) + +} + +func TestRateLimiterLimitWithoutUserIdentity(t *testing.T) { + rateLimit := NewRateLimiter(1, 1, time.Second) + assert.NotNil(t, rateLimit) + + ctx := context.TODO() + + err := rateLimit.Limit(ctx) + assert.Error(t, err) } From 808ec458aaa1b46bf4024537c67a429de08c2f49 Mon Sep 17 00:00:00 2001 From: TungHoang Date: Tue, 9 May 2023 21:24:45 +0200 Subject: [PATCH 4/8] Add details docs for configs Signed-off-by: TungHoang --- pkg/config/config.go | 12 +++--- pkg/config/serverconfig_flags.go | 4 ++ pkg/config/serverconfig_flags_test.go | 56 +++++++++++++++++++++++++++ pkg/server/service.go | 3 ++ plugins/rate_limit.go | 8 ++-- 5 files changed, 73 insertions(+), 10 deletions(-) diff --git a/pkg/config/config.go b/pkg/config/config.go index a855f75dd..d86c07333 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -81,7 +81,7 @@ type ServerSecurityOptions struct { // By default, the server will allow Accept, Accept-Language, Content-Language, and Content-Type. // DeprecatedUser this setting to add any additional headers which are needed AllowedHeaders []string `json:"allowedHeaders"` - RateLimit RateLimitOptions `json:"rateLimitOptions"` + RateLimit RateLimitOptions `json:"rateLimit"` } type SslOptions struct { @@ -89,12 +89,12 @@ type SslOptions struct { KeyFile string `json:"keyFile"` } -// declare RateLimitConfig +// RateLimitOptions is a type to hold rate limit configuration options. type RateLimitOptions struct { - Enabled bool `json:"enabled"` - RequestsPerSecond int `json:"requestsPerSecond"` - BurstSize int `json:"burstSize"` - CleanupInterval config.Duration `json:"cleanupInterval"` + Enabled bool `json:"enabled" pflag:",Controls whether rate limiting is enabled. If enabled, the rate limit is applied to all requests using the TokenBucket algorithm."` + RequestsPerSecond int `json:"requestsPerSecond" pflag:",The number of requests allowed per second."` + BurstSize int `json:"burstSize" pflag:",The number of requests allowed to burst. 0 implies the TokenBucket algorithm cannot hold any tokens."` + CleanupInterval config.Duration `json:"cleanupInterval" pflag:",The interval at which the rate limiter cleans up entries that have not been used for a certain period of time."` } var defaultServerConfig = &ServerConfig{ diff --git a/pkg/config/serverconfig_flags.go b/pkg/config/serverconfig_flags.go index c37a82603..fe7e00e00 100755 --- a/pkg/config/serverconfig_flags.go +++ b/pkg/config/serverconfig_flags.go @@ -63,6 +63,10 @@ func (cfg ServerConfig) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "security.allowCors"), defaultServerConfig.Security.AllowCors, "") cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "security.allowedOrigins"), defaultServerConfig.Security.AllowedOrigins, "") cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "security.allowedHeaders"), defaultServerConfig.Security.AllowedHeaders, "") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "security.rateLimit.enabled"), defaultServerConfig.Security.RateLimit.Enabled, "Controls whether rate limiting is enabled. If enabled, the rate limit is applied to all requests using the TokenBucket algorithm.") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "security.rateLimit.requestsPerSecond"), defaultServerConfig.Security.RateLimit.RequestsPerSecond, "The number of requests allowed per second.") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "security.rateLimit.burstSize"), defaultServerConfig.Security.RateLimit.BurstSize, "The number of requests allowed to burst. 0 implies the TokenBucket algorithm cannot hold any tokens.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "security.rateLimit.cleanupInterval"), defaultServerConfig.Security.RateLimit.CleanupInterval.String(), "The interval at which the rate limiter cleans up entries that have not been used for a certain period of time.") cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "grpc.port"), defaultServerConfig.GrpcConfig.Port, "On which grpc port to serve admin") cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "grpc.serverReflection"), defaultServerConfig.GrpcConfig.ServerReflection, "Enable GRPC Server Reflection") cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "grpc.maxMessageSizeBytes"), defaultServerConfig.GrpcConfig.MaxMessageSizeBytes, "The max size in bytes for incoming gRPC messages") diff --git a/pkg/config/serverconfig_flags_test.go b/pkg/config/serverconfig_flags_test.go index b16e0416d..f5eaa0752 100755 --- a/pkg/config/serverconfig_flags_test.go +++ b/pkg/config/serverconfig_flags_test.go @@ -281,6 +281,62 @@ func TestServerConfig_SetFlags(t *testing.T) { } }) }) + t.Run("Test_security.rateLimit.enabled", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("security.rateLimit.enabled", testValue) + if vBool, err := cmdFlags.GetBool("security.rateLimit.enabled"); err == nil { + testDecodeJson_ServerConfig(t, fmt.Sprintf("%v", vBool), &actual.Security.RateLimit.Enabled) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_security.rateLimit.requestsPerSecond", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("security.rateLimit.requestsPerSecond", testValue) + if vInt, err := cmdFlags.GetInt("security.rateLimit.requestsPerSecond"); err == nil { + testDecodeJson_ServerConfig(t, fmt.Sprintf("%v", vInt), &actual.Security.RateLimit.RequestsPerSecond) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_security.rateLimit.burstSize", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("security.rateLimit.burstSize", testValue) + if vInt, err := cmdFlags.GetInt("security.rateLimit.burstSize"); err == nil { + testDecodeJson_ServerConfig(t, fmt.Sprintf("%v", vInt), &actual.Security.RateLimit.BurstSize) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_security.rateLimit.cleanupInterval", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := defaultServerConfig.Security.RateLimit.CleanupInterval.String() + + cmdFlags.Set("security.rateLimit.cleanupInterval", testValue) + if vString, err := cmdFlags.GetString("security.rateLimit.cleanupInterval"); err == nil { + testDecodeJson_ServerConfig(t, fmt.Sprintf("%v", vString), &actual.Security.RateLimit.CleanupInterval) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) t.Run("Test_grpc.port", func(t *testing.T) { t.Run("Override", func(t *testing.T) { diff --git a/pkg/server/service.go b/pkg/server/service.go index 8d226e7b3..787eea209 100644 --- a/pkg/server/service.go +++ b/pkg/server/service.go @@ -98,6 +98,9 @@ func newGRPCServer(ctx context.Context, pluginRegistry *plugins.Registry, cfg *c } else { logger.Infof(ctx, "Creating gRPC server without authentication") chainedUnaryInterceptors = grpcmiddleware.ChainUnaryServer(grpcprometheus.UnaryServerInterceptor) + if cfg.Security.RateLimit.Enabled { + logger.Warningf(ctx, "Rate limit is enabled but auth is not") + } } serverOpts := []grpc.ServerOption{ diff --git a/plugins/rate_limit.go b/plugins/rate_limit.go index d32cdf985..07800b728 100644 --- a/plugins/rate_limit.go +++ b/plugins/rate_limit.go @@ -16,13 +16,15 @@ import ( type RateLimitExceeded error -// define a struct that contains a map of rate limiters, and a time stamp of last access and a mutex to protect the map +// accessRecords stores the rate limiter and the last access time type accessRecords struct { limiter *rate.Limiter lastAccess time.Time } +// LimiterStore stores the access records for each user type LimiterStore struct { + // accessPerUser is a synchronized map of userID to accessRecords accessPerUser map[string]*accessRecords mutex *sync.Mutex requestPerSec int @@ -30,9 +32,7 @@ type LimiterStore struct { cleanupInterval time.Duration } -// define a function named Allow that takes userID and returns RateLimitError -// the function check if the user is in the map, if not, create a new accessRecords for the user -// then it check if the user can access the resource, if not, return RateLimitError +// Allow takes a userID and returns an error if the user has exceeded the rate limit func (l *LimiterStore) Allow(userID string) error { l.mutex.Lock() defer l.mutex.Unlock() From 260e71d3700f38dd2b2529bb9b6b38948b876c39 Mon Sep 17 00:00:00 2001 From: TungHoang Date: Wed, 10 May 2023 11:45:57 +0200 Subject: [PATCH 5/8] Add lastAccess timestamp test for RateLimiter Signed-off-by: TungHoang --- plugins/rate_limit.go | 1 + plugins/rate_limit_test.go | 9 +++++++++ 2 files changed, 10 insertions(+) diff --git a/plugins/rate_limit.go b/plugins/rate_limit.go index 07800b728..90c85f39f 100644 --- a/plugins/rate_limit.go +++ b/plugins/rate_limit.go @@ -42,6 +42,7 @@ func (l *LimiterStore) Allow(userID string) error { limiter: rate.NewLimiter(rate.Limit(l.requestPerSec), l.burstSize), } } + l.accessPerUser[userID].lastAccess = time.Now() if !l.accessPerUser[userID].limiter.Allow() { return RateLimitExceeded(fmt.Errorf("rate limit exceeded")) diff --git a/plugins/rate_limit_test.go b/plugins/rate_limit_test.go index 87907b2f1..dfddb851d 100644 --- a/plugins/rate_limit_test.go +++ b/plugins/rate_limit_test.go @@ -93,3 +93,12 @@ func TestRateLimiterLimitWithoutUserIdentity(t *testing.T) { err := rateLimit.Limit(ctx) assert.Error(t, err) } + +func TestRateLimiterUpdateLastAccessTime(t *testing.T) { + rlStore := newRateLimitStore(2, 2, time.Second) + assert.NoError(t, rlStore.Allow("hello")) + lastAccessTime := rlStore.accessPerUser["hello"].lastAccess + assert.NoError(t, rlStore.Allow("hello")) + newAccessTime := rlStore.accessPerUser["hello"].lastAccess + assert.True(t, newAccessTime.After(lastAccessTime)) +} From e9585fb4681a9276da872e57c39f0d2938392349 Mon Sep 17 00:00:00 2001 From: TungHoang Date: Wed, 10 May 2023 13:44:29 +0200 Subject: [PATCH 6/8] Add timestamp unit test for RateLimiter Signed-off-by: TungHoang --- plugins/rate_limit_test.go | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/plugins/rate_limit_test.go b/plugins/rate_limit_test.go index dfddb851d..b06bc5994 100644 --- a/plugins/rate_limit_test.go +++ b/plugins/rate_limit_test.go @@ -97,8 +97,13 @@ func TestRateLimiterLimitWithoutUserIdentity(t *testing.T) { func TestRateLimiterUpdateLastAccessTime(t *testing.T) { rlStore := newRateLimitStore(2, 2, time.Second) assert.NoError(t, rlStore.Allow("hello")) - lastAccessTime := rlStore.accessPerUser["hello"].lastAccess + firstAccessTime := rlStore.accessPerUser["hello"].lastAccess assert.NoError(t, rlStore.Allow("hello")) - newAccessTime := rlStore.accessPerUser["hello"].lastAccess - assert.True(t, newAccessTime.After(lastAccessTime)) + secondAccessTime := rlStore.accessPerUser["hello"].lastAccess + assert.True(t, secondAccessTime.After(firstAccessTime)) + // Verify that the last access time is updated even when user is rate limited + assert.Error(t, rlStore.Allow("hello")) + thirdAccessTime := rlStore.accessPerUser["hello"].lastAccess + assert.True(t, thirdAccessTime.After(secondAccessTime)) + } From 6061035f5caf659e8047b213186f837e31bf43e7 Mon Sep 17 00:00:00 2001 From: TungHoang Date: Wed, 10 May 2023 17:41:10 +0200 Subject: [PATCH 7/8] Use sync.Map to avoid contention on RateLimiter Signed-off-by: TungHoang --- plugins/rate_limit.go | 43 ++++++++++++++++++++++---------------- plugins/rate_limit_test.go | 25 ++++++++++++++++++---- 2 files changed, 46 insertions(+), 22 deletions(-) diff --git a/plugins/rate_limit.go b/plugins/rate_limit.go index 90c85f39f..4c81657de 100644 --- a/plugins/rate_limit.go +++ b/plugins/rate_limit.go @@ -20,50 +20,56 @@ type RateLimitExceeded error type accessRecords struct { limiter *rate.Limiter lastAccess time.Time + mutex *sync.Mutex } // LimiterStore stores the access records for each user type LimiterStore struct { // accessPerUser is a synchronized map of userID to accessRecords - accessPerUser map[string]*accessRecords - mutex *sync.Mutex + accessPerUser *sync.Map requestPerSec int burstSize int + mutex *sync.Mutex cleanupInterval time.Duration } // Allow takes a userID and returns an error if the user has exceeded the rate limit func (l *LimiterStore) Allow(userID string) error { - l.mutex.Lock() - defer l.mutex.Unlock() - if _, ok := l.accessPerUser[userID]; !ok { - l.accessPerUser[userID] = &accessRecords{ - lastAccess: time.Now(), - limiter: rate.NewLimiter(rate.Limit(l.requestPerSec), l.burstSize), - } - } - l.accessPerUser[userID].lastAccess = time.Now() - - if !l.accessPerUser[userID].limiter.Allow() { + accessRecord, _ := l.accessPerUser.LoadOrStore(userID, &accessRecords{ + limiter: rate.NewLimiter(rate.Limit(l.requestPerSec), l.burstSize), + lastAccess: time.Now(), + mutex: &sync.Mutex{}, + }) + accessRecord.(*accessRecords).mutex.Lock() + defer accessRecord.(*accessRecords).mutex.Unlock() + + accessRecord.(*accessRecords).lastAccess = time.Now() + l.accessPerUser.Store(userID, accessRecord) + + if !accessRecord.(*accessRecords).limiter.Allow() { return RateLimitExceeded(fmt.Errorf("rate limit exceeded")) } return nil } +// clean removes the access records for users who have not accessed the system for a while func (l *LimiterStore) clean() { l.mutex.Lock() defer l.mutex.Unlock() - for userID, accessRecord := range l.accessPerUser { - if time.Since(accessRecord.lastAccess) > l.cleanupInterval { - delete(l.accessPerUser, userID) + l.accessPerUser.Range(func(key, value interface{}) bool { + value.(*accessRecords).mutex.Lock() + defer value.(*accessRecords).mutex.Unlock() + if time.Since(value.(*accessRecords).lastAccess) > l.cleanupInterval { + l.accessPerUser.Delete(key) } - } + return true + }) } func newRateLimitStore(requestPerSec int, burstSize int, cleanupInterval time.Duration) *LimiterStore { l := &LimiterStore{ - accessPerUser: make(map[string]*accessRecords), + accessPerUser: &sync.Map{}, mutex: &sync.Mutex{}, requestPerSec: requestPerSec, burstSize: burstSize, @@ -80,6 +86,7 @@ func newRateLimitStore(requestPerSec int, burstSize int, cleanupInterval time.Du return l } +// RateLimiter is a struct that implements the RateLimiter interface from grpc middleware type RateLimiter struct { limiter *LimiterStore } diff --git a/plugins/rate_limit_test.go b/plugins/rate_limit_test.go index b06bc5994..cafcf3d00 100644 --- a/plugins/rate_limit_test.go +++ b/plugins/rate_limit_test.go @@ -15,7 +15,7 @@ func TestNewRateLimiter(t *testing.T) { } func TestLimiterAllow(t *testing.T) { - rlStore := newRateLimitStore(1, 1, time.Second) + rlStore := newRateLimitStore(1, 1, 10*time.Second) assert.NoError(t, rlStore.Allow("hello")) assert.Error(t, rlStore.Allow("hello")) time.Sleep(time.Second) @@ -97,13 +97,30 @@ func TestRateLimiterLimitWithoutUserIdentity(t *testing.T) { func TestRateLimiterUpdateLastAccessTime(t *testing.T) { rlStore := newRateLimitStore(2, 2, time.Second) assert.NoError(t, rlStore.Allow("hello")) - firstAccessTime := rlStore.accessPerUser["hello"].lastAccess + // get last access time + + accessRecord, _ := rlStore.accessPerUser.Load("hello") + accessRecord.(*accessRecords).mutex.Lock() + firstAccessTime := accessRecord.(*accessRecords).lastAccess + accessRecord.(*accessRecords).mutex.Unlock() + assert.NoError(t, rlStore.Allow("hello")) - secondAccessTime := rlStore.accessPerUser["hello"].lastAccess + + accessRecord, _ = rlStore.accessPerUser.Load("hello") + accessRecord.(*accessRecords).mutex.Lock() + secondAccessTime := accessRecord.(*accessRecords).lastAccess + accessRecord.(*accessRecords).mutex.Unlock() + assert.True(t, secondAccessTime.After(firstAccessTime)) + // Verify that the last access time is updated even when user is rate limited assert.Error(t, rlStore.Allow("hello")) - thirdAccessTime := rlStore.accessPerUser["hello"].lastAccess + + accessRecord, _ = rlStore.accessPerUser.Load("hello") + accessRecord.(*accessRecords).mutex.Lock() + thirdAccessTime := accessRecord.(*accessRecords).lastAccess + accessRecord.(*accessRecords).mutex.Unlock() + assert.True(t, thirdAccessTime.After(secondAccessTime)) } From dfaf1836a9da994bcd54ba040fc98ec93cb770f1 Mon Sep 17 00:00:00 2001 From: TungHoang Date: Thu, 11 May 2023 09:29:20 +0200 Subject: [PATCH 8/8] Remove global lock in RateLimitStore Signed-off-by: TungHoang --- plugins/rate_limit.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/plugins/rate_limit.go b/plugins/rate_limit.go index 4c81657de..3496fccc2 100644 --- a/plugins/rate_limit.go +++ b/plugins/rate_limit.go @@ -29,7 +29,6 @@ type LimiterStore struct { accessPerUser *sync.Map requestPerSec int burstSize int - mutex *sync.Mutex cleanupInterval time.Duration } @@ -55,8 +54,6 @@ func (l *LimiterStore) Allow(userID string) error { // clean removes the access records for users who have not accessed the system for a while func (l *LimiterStore) clean() { - l.mutex.Lock() - defer l.mutex.Unlock() l.accessPerUser.Range(func(key, value interface{}) bool { value.(*accessRecords).mutex.Lock() defer value.(*accessRecords).mutex.Unlock() @@ -70,7 +67,6 @@ func (l *LimiterStore) clean() { func newRateLimitStore(requestPerSec int, burstSize int, cleanupInterval time.Duration) *LimiterStore { l := &LimiterStore{ accessPerUser: &sync.Map{}, - mutex: &sync.Mutex{}, requestPerSec: requestPerSec, burstSize: burstSize, cleanupInterval: cleanupInterval,