diff --git a/pkg/config/config.go b/pkg/config/config.go index 28996bd23..d86c07333 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -80,7 +80,8 @@ type ServerSecurityOptions struct { // These are the Access-Control-Request-Headers that the server will respond to. // By default, the server will allow Accept, Accept-Language, Content-Language, and Content-Type. // DeprecatedUser this setting to add any additional headers which are needed - AllowedHeaders []string `json:"allowedHeaders"` + AllowedHeaders []string `json:"allowedHeaders"` + RateLimit RateLimitOptions `json:"rateLimit"` } type SslOptions struct { @@ -88,6 +89,14 @@ type SslOptions struct { KeyFile string `json:"keyFile"` } +// RateLimitOptions is a type to hold rate limit configuration options. +type RateLimitOptions struct { + Enabled bool `json:"enabled" pflag:",Controls whether rate limiting is enabled. If enabled, the rate limit is applied to all requests using the TokenBucket algorithm."` + RequestsPerSecond int `json:"requestsPerSecond" pflag:",The number of requests allowed per second."` + BurstSize int `json:"burstSize" pflag:",The number of requests allowed to burst. 0 implies the TokenBucket algorithm cannot hold any tokens."` + CleanupInterval config.Duration `json:"cleanupInterval" pflag:",The interval at which the rate limiter cleans up entries that have not been used for a certain period of time."` +} + var defaultServerConfig = &ServerConfig{ HTTPPort: 8088, Security: ServerSecurityOptions{ diff --git a/pkg/config/serverconfig_flags.go b/pkg/config/serverconfig_flags.go index c37a82603..fe7e00e00 100755 --- a/pkg/config/serverconfig_flags.go +++ b/pkg/config/serverconfig_flags.go @@ -63,6 +63,10 @@ func (cfg ServerConfig) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "security.allowCors"), defaultServerConfig.Security.AllowCors, "") cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "security.allowedOrigins"), defaultServerConfig.Security.AllowedOrigins, "") cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "security.allowedHeaders"), defaultServerConfig.Security.AllowedHeaders, "") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "security.rateLimit.enabled"), defaultServerConfig.Security.RateLimit.Enabled, "Controls whether rate limiting is enabled. If enabled, the rate limit is applied to all requests using the TokenBucket algorithm.") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "security.rateLimit.requestsPerSecond"), defaultServerConfig.Security.RateLimit.RequestsPerSecond, "The number of requests allowed per second.") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "security.rateLimit.burstSize"), defaultServerConfig.Security.RateLimit.BurstSize, "The number of requests allowed to burst. 0 implies the TokenBucket algorithm cannot hold any tokens.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "security.rateLimit.cleanupInterval"), defaultServerConfig.Security.RateLimit.CleanupInterval.String(), "The interval at which the rate limiter cleans up entries that have not been used for a certain period of time.") cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "grpc.port"), defaultServerConfig.GrpcConfig.Port, "On which grpc port to serve admin") cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "grpc.serverReflection"), defaultServerConfig.GrpcConfig.ServerReflection, "Enable GRPC Server Reflection") cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "grpc.maxMessageSizeBytes"), defaultServerConfig.GrpcConfig.MaxMessageSizeBytes, "The max size in bytes for incoming gRPC messages") diff --git a/pkg/config/serverconfig_flags_test.go b/pkg/config/serverconfig_flags_test.go index b16e0416d..f5eaa0752 100755 --- a/pkg/config/serverconfig_flags_test.go +++ b/pkg/config/serverconfig_flags_test.go @@ -281,6 +281,62 @@ func TestServerConfig_SetFlags(t *testing.T) { } }) }) + t.Run("Test_security.rateLimit.enabled", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("security.rateLimit.enabled", testValue) + if vBool, err := cmdFlags.GetBool("security.rateLimit.enabled"); err == nil { + testDecodeJson_ServerConfig(t, fmt.Sprintf("%v", vBool), &actual.Security.RateLimit.Enabled) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_security.rateLimit.requestsPerSecond", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("security.rateLimit.requestsPerSecond", testValue) + if vInt, err := cmdFlags.GetInt("security.rateLimit.requestsPerSecond"); err == nil { + testDecodeJson_ServerConfig(t, fmt.Sprintf("%v", vInt), &actual.Security.RateLimit.RequestsPerSecond) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_security.rateLimit.burstSize", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("security.rateLimit.burstSize", testValue) + if vInt, err := cmdFlags.GetInt("security.rateLimit.burstSize"); err == nil { + testDecodeJson_ServerConfig(t, fmt.Sprintf("%v", vInt), &actual.Security.RateLimit.BurstSize) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_security.rateLimit.cleanupInterval", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := defaultServerConfig.Security.RateLimit.CleanupInterval.String() + + cmdFlags.Set("security.rateLimit.cleanupInterval", testValue) + if vString, err := cmdFlags.GetString("security.rateLimit.cleanupInterval"); err == nil { + testDecodeJson_ServerConfig(t, fmt.Sprintf("%v", vString), &actual.Security.RateLimit.CleanupInterval) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) t.Run("Test_grpc.port", func(t *testing.T) { t.Run("Override", func(t *testing.T) { diff --git a/pkg/server/service.go b/pkg/server/service.go index f3b27416f..787eea209 100644 --- a/pkg/server/service.go +++ b/pkg/server/service.go @@ -90,9 +90,17 @@ func newGRPCServer(ctx context.Context, pluginRegistry *plugins.Registry, cfg *c auth.AuthenticationLoggingInterceptor, middlewareInterceptors, ) + if cfg.Security.RateLimit.Enabled { + rateLimiter := plugins.NewRateLimiter(cfg.Security.RateLimit.RequestsPerSecond, cfg.Security.RateLimit.BurstSize, cfg.Security.RateLimit.CleanupInterval.Duration) + rateLimitInterceptors := plugins.RateLimiteInterceptor(*rateLimiter) + chainedUnaryInterceptors = grpcmiddleware.ChainUnaryServer(chainedUnaryInterceptors, rateLimitInterceptors) + } } else { logger.Infof(ctx, "Creating gRPC server without authentication") chainedUnaryInterceptors = grpcmiddleware.ChainUnaryServer(grpcprometheus.UnaryServerInterceptor) + if cfg.Security.RateLimit.Enabled { + logger.Warningf(ctx, "Rate limit is enabled but auth is not") + } } serverOpts := []grpc.ServerOption{ @@ -257,6 +265,7 @@ func serveGatewayInsecure(ctx context.Context, pluginRegistry *plugins.Registry, } oauth2ResourceServer = oauth2Provider + } else { oauth2ResourceServer, err = authzserver.NewOAuth2ResourceServer(ctx, authCfg.AppAuth.ExternalAuthServer, authCfg.UserAuth.OpenID.BaseURL) if err != nil { diff --git a/plugins/rate_limit.go b/plugins/rate_limit.go new file mode 100644 index 000000000..3496fccc2 --- /dev/null +++ b/plugins/rate_limit.go @@ -0,0 +1,116 @@ +package plugins + +import ( + "context" + "errors" + "fmt" + "sync" + "time" + + auth "github.com/flyteorg/flyteadmin/auth" + "golang.org/x/time/rate" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +type RateLimitExceeded error + +// accessRecords stores the rate limiter and the last access time +type accessRecords struct { + limiter *rate.Limiter + lastAccess time.Time + mutex *sync.Mutex +} + +// LimiterStore stores the access records for each user +type LimiterStore struct { + // accessPerUser is a synchronized map of userID to accessRecords + accessPerUser *sync.Map + requestPerSec int + burstSize int + cleanupInterval time.Duration +} + +// Allow takes a userID and returns an error if the user has exceeded the rate limit +func (l *LimiterStore) Allow(userID string) error { + accessRecord, _ := l.accessPerUser.LoadOrStore(userID, &accessRecords{ + limiter: rate.NewLimiter(rate.Limit(l.requestPerSec), l.burstSize), + lastAccess: time.Now(), + mutex: &sync.Mutex{}, + }) + accessRecord.(*accessRecords).mutex.Lock() + defer accessRecord.(*accessRecords).mutex.Unlock() + + accessRecord.(*accessRecords).lastAccess = time.Now() + l.accessPerUser.Store(userID, accessRecord) + + if !accessRecord.(*accessRecords).limiter.Allow() { + return RateLimitExceeded(fmt.Errorf("rate limit exceeded")) + } + + return nil +} + +// clean removes the access records for users who have not accessed the system for a while +func (l *LimiterStore) clean() { + l.accessPerUser.Range(func(key, value interface{}) bool { + value.(*accessRecords).mutex.Lock() + defer value.(*accessRecords).mutex.Unlock() + if time.Since(value.(*accessRecords).lastAccess) > l.cleanupInterval { + l.accessPerUser.Delete(key) + } + return true + }) +} + +func newRateLimitStore(requestPerSec int, burstSize int, cleanupInterval time.Duration) *LimiterStore { + l := &LimiterStore{ + accessPerUser: &sync.Map{}, + requestPerSec: requestPerSec, + burstSize: burstSize, + cleanupInterval: cleanupInterval, + } + + go func() { + for { + time.Sleep(l.cleanupInterval) + l.clean() + } + }() + + return l +} + +// RateLimiter is a struct that implements the RateLimiter interface from grpc middleware +type RateLimiter struct { + limiter *LimiterStore +} + +func (r *RateLimiter) Limit(ctx context.Context) error { + IdenCtx := auth.IdentityContextFromContext(ctx) + if IdenCtx.IsEmpty() { + return errors.New("no identity context found") + } + userID := IdenCtx.UserID() + if err := r.limiter.Allow(userID); err != nil { + return err + } + return nil +} + +func NewRateLimiter(requestPerSec int, burstSize int, cleanupInterval time.Duration) *RateLimiter { + limiter := newRateLimitStore(requestPerSec, burstSize, cleanupInterval) + return &RateLimiter{limiter: limiter} +} + +func RateLimiteInterceptor(limiter RateLimiter) grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) ( + resp interface{}, err error) { + if err := limiter.Limit(ctx); err != nil { + return nil, status.Errorf(codes.ResourceExhausted, "rate limit exceeded") + } + + return handler(ctx, req) + } +} diff --git a/plugins/rate_limit_test.go b/plugins/rate_limit_test.go new file mode 100644 index 000000000..cafcf3d00 --- /dev/null +++ b/plugins/rate_limit_test.go @@ -0,0 +1,126 @@ +package plugins + +import ( + "context" + "testing" + "time" + + auth "github.com/flyteorg/flyteadmin/auth" + "github.com/stretchr/testify/assert" +) + +func TestNewRateLimiter(t *testing.T) { + rlStore := newRateLimitStore(1, 1, time.Second) + assert.NotNil(t, rlStore) +} + +func TestLimiterAllow(t *testing.T) { + rlStore := newRateLimitStore(1, 1, 10*time.Second) + assert.NoError(t, rlStore.Allow("hello")) + assert.Error(t, rlStore.Allow("hello")) + time.Sleep(time.Second) + assert.NoError(t, rlStore.Allow("hello")) +} + +func TestLimiterAllowBurst(t *testing.T) { + rlStore := newRateLimitStore(1, 2, time.Second) + assert.NoError(t, rlStore.Allow("hello")) + assert.NoError(t, rlStore.Allow("hello")) + assert.Error(t, rlStore.Allow("hello")) + assert.NoError(t, rlStore.Allow("world")) +} + +func TestLimiterClean(t *testing.T) { + rlStore := newRateLimitStore(1, 1, time.Second) + assert.NoError(t, rlStore.Allow("hello")) + assert.Error(t, rlStore.Allow("hello")) + time.Sleep(time.Second) + rlStore.clean() + assert.NoError(t, rlStore.Allow("hello")) +} + +func TestLimiterAllowOnMultipleRequests(t *testing.T) { + rlStore := newRateLimitStore(1, 1, time.Second) + assert.NoError(t, rlStore.Allow("a")) + assert.NoError(t, rlStore.Allow("b")) + assert.NoError(t, rlStore.Allow("c")) + assert.Error(t, rlStore.Allow("a")) + assert.Error(t, rlStore.Allow("b")) + + time.Sleep(time.Second) + + assert.NoError(t, rlStore.Allow("a")) + assert.Error(t, rlStore.Allow("a")) + assert.NoError(t, rlStore.Allow("b")) + assert.Error(t, rlStore.Allow("b")) + assert.NoError(t, rlStore.Allow("c")) +} + +func TestRateLimiterLimitPass(t *testing.T) { + rateLimit := NewRateLimiter(1, 1, time.Second) + assert.NotNil(t, rateLimit) + + identityCtx, err := auth.NewIdentityContext("audience", "user1", "app1", time.Now(), nil, nil, nil) + assert.NoError(t, err) + + ctx := context.WithValue(context.TODO(), auth.ContextKeyIdentityContext, identityCtx) + err = rateLimit.Limit(ctx) + assert.NoError(t, err) + +} + +func TestRateLimiterLimitStop(t *testing.T) { + rateLimit := NewRateLimiter(1, 1, time.Second) + assert.NotNil(t, rateLimit) + + identityCtx, err := auth.NewIdentityContext("audience", "user1", "app1", time.Now(), nil, nil, nil) + assert.NoError(t, err) + ctx := context.WithValue(context.TODO(), auth.ContextKeyIdentityContext, identityCtx) + err = rateLimit.Limit(ctx) + assert.NoError(t, err) + + err = rateLimit.Limit(ctx) + assert.Error(t, err) + +} + +func TestRateLimiterLimitWithoutUserIdentity(t *testing.T) { + rateLimit := NewRateLimiter(1, 1, time.Second) + assert.NotNil(t, rateLimit) + + ctx := context.TODO() + + err := rateLimit.Limit(ctx) + assert.Error(t, err) +} + +func TestRateLimiterUpdateLastAccessTime(t *testing.T) { + rlStore := newRateLimitStore(2, 2, time.Second) + assert.NoError(t, rlStore.Allow("hello")) + // get last access time + + accessRecord, _ := rlStore.accessPerUser.Load("hello") + accessRecord.(*accessRecords).mutex.Lock() + firstAccessTime := accessRecord.(*accessRecords).lastAccess + accessRecord.(*accessRecords).mutex.Unlock() + + assert.NoError(t, rlStore.Allow("hello")) + + accessRecord, _ = rlStore.accessPerUser.Load("hello") + accessRecord.(*accessRecords).mutex.Lock() + secondAccessTime := accessRecord.(*accessRecords).lastAccess + accessRecord.(*accessRecords).mutex.Unlock() + + assert.True(t, secondAccessTime.After(firstAccessTime)) + + // Verify that the last access time is updated even when user is rate limited + assert.Error(t, rlStore.Allow("hello")) + + accessRecord, _ = rlStore.accessPerUser.Load("hello") + accessRecord.(*accessRecords).mutex.Lock() + thirdAccessTime := accessRecord.(*accessRecords).lastAccess + accessRecord.(*accessRecords).mutex.Unlock() + + assert.True(t, thirdAccessTime.After(secondAccessTime)) + +}