diff --git a/utils/rate.go b/utils/rate.go index 13fe73d9..266dced7 100644 --- a/utils/rate.go +++ b/utils/rate.go @@ -31,28 +31,44 @@ import ( ) type LeakyBucket struct { - mutex sync.Mutex - last time.Time - sleepFor time.Duration - perRequest atomic.Duration + mutex sync.Mutex + last time.Time + sleepFor time.Duration + cfg atomic.Pointer[leakyBucketConfig] + clock Clock +} + +type leakyBucketConfig struct { + perRequest time.Duration maxSlack time.Duration - clock Clock } -func NewLeakyBucket(rateLimit int, slack time.Duration, clock Clock) *LeakyBucket { +// NewLeakyBucket initiates LeakyBucket with rateLimit, slack, and clock. +// +// rateLimit is defined as the number of request per second. +// +// slack is defined as the number of allowed requests before limiting. +// e.g. when slack=5, LeakyBucket will allow 5 requests to pass through Take +// without a sleep as long as these requests are under perRequest duration. +func NewLeakyBucket(rateLimit int, slack int, clock Clock) *LeakyBucket { var lb LeakyBucket - lb.SetRateLimit(rateLimit) - lb.maxSlack = -1 * time.Duration(slack) * lb.perRequest.Load() lb.clock = clock + lb.Update(rateLimit, slack) return &lb } -// SetRateLimit sets the underlying rate limit. +// Update sets the underlying rate limit and slack. // The setting may not be applied immediately. // -// SetRateLimit is THREAD SAFE and NON-BLOCKING. -func (lb *LeakyBucket) SetRateLimit(rateLimit int) { - lb.perRequest.Store(time.Second / time.Duration(rateLimit)) +// Update is THREAD SAFE and NON-BLOCKING. +func (lb *LeakyBucket) Update(rateLimit int, slack int) { + perRequest := time.Second / time.Duration(rateLimit) + maxSlack := -1 * time.Duration(slack) * perRequest + cfg := leakyBucketConfig{ + perRequest: perRequest, + maxSlack: maxSlack, + } + lb.cfg.Store(&cfg) } // Take blocks to ensure that the time spent between multiple Take calls @@ -63,6 +79,7 @@ func (lb *LeakyBucket) Take() time.Time { lb.mutex.Lock() defer lb.mutex.Unlock() + cfg := lb.cfg.Load() now := lb.clock.Now() // If this is our first request, then we allow it. @@ -75,13 +92,13 @@ func (lb *LeakyBucket) Take() time.Time { // the perRequest budget and how long the last request took. // Since the request may take longer than the budget, this number // can get negative, and is summed across requests. - lb.sleepFor += lb.perRequest.Load() - now.Sub(lb.last) + lb.sleepFor += cfg.perRequest - now.Sub(lb.last) // We shouldn't allow sleepFor to get too negative, since it would mean that // a service that slowed down a lot for a short period of time would get // a much higher RPS following that. - if lb.sleepFor < lb.maxSlack { - lb.sleepFor = lb.maxSlack + if lb.sleepFor < cfg.maxSlack { + lb.sleepFor = cfg.maxSlack } // If sleepFor is positive, then we should sleep now. diff --git a/utils/rate_test.go b/utils/rate_test.go index 9d2125c5..961234c0 100644 --- a/utils/rate_test.go +++ b/utils/rate_test.go @@ -32,6 +32,7 @@ import ( "github.com/benbjohnson/clock" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) const UnstableTest = "UNSTABLE TEST" @@ -161,11 +162,14 @@ func runTest(t *testing.T, fn func(testRunner)) { constructor: func(rate int, opts ...Option) Limiter { config := buildConfig(opts) perRequest := config.per / time.Duration(rate) + cfg := leakyBucketConfig{ + perRequest: perRequest, + maxSlack: -1 * time.Duration(config.slack) * perRequest, + } l := &LeakyBucket{ - maxSlack: -1 * time.Duration(config.slack) * perRequest, - clock: config.clock, + clock: config.clock, } - l.perRequest.Store(perRequest) + l.cfg.Store(&cfg) return l }, }, @@ -495,7 +499,7 @@ func TestSetRateLimitOnTheFly(t *testing.T) { // Set rate to 1hz limiter, ok := r.createLimiter(1, WithoutSlack).(*LeakyBucket) if !ok { - t.Skip("SetRateLimit is not supported") + t.Skip("Update is not supported") } r.startTaking(limiter) @@ -505,17 +509,23 @@ func TestSetRateLimitOnTheFly(t *testing.T) { r.assertCountAt(time.Second, 3) // increase to 2hz - limiter.SetRateLimit(2) + limiter.Update(2, 0) r.getClock().Add(time.Second) r.assertCountAt(time.Second, 4) // <- delayed due to paying sleepFor debt r.getClock().Add(time.Second) r.assertCountAt(time.Second, 6) // reduce to 1hz again - limiter.SetRateLimit(1) + limiter.Update(1, 0) r.getClock().Add(time.Second) r.assertCountAt(time.Second, 7) r.getClock().Add(time.Second) r.assertCountAt(time.Second, 8) + + slack := 3 + require.GreaterOrEqual(t, limiter.sleepFor, time.Duration(0)) + limiter.Update(1, slack) + r.getClock().Add(time.Second * time.Duration(slack)) + r.assertCountAt(time.Second, 8+slack) }) }