diff --git a/internal/hammer/hammer.go b/internal/hammer/hammer.go index a0abb3e5..3b50c47d 100644 --- a/internal/hammer/hammer.go +++ b/internal/hammer/hammer.go @@ -408,13 +408,16 @@ func NewThrottle(opsPerSecond int) *Throttle { } type Throttle struct { - opsPerSecond int tokenChan chan bool + mu sync.Mutex + opsPerSecond int oversupply int } func (t *Throttle) Increase() { + t.mu.Lock() + defer t.mu.Unlock() tokenCount := t.opsPerSecond delta := float64(tokenCount) * 0.1 if delta < 1 { @@ -424,6 +427,8 @@ func (t *Throttle) Increase() { } func (t *Throttle) Decrease() { + t.mu.Lock() + defer t.mu.Unlock() tokenCount := t.opsPerSecond if tokenCount <= 1 { return @@ -443,20 +448,27 @@ func (t *Throttle) Run(ctx context.Context) { case <-ctx.Done(): //context cancelled return case <-ticker.C: - tokenCount := t.opsPerSecond - timeout := time.After(interval) - Loop: - for i := 0; i < t.opsPerSecond; i++ { - select { - case t.tokenChan <- true: - tokenCount-- - case <-timeout: - break Loop - } - } + ctx, cancel := context.WithTimeout(ctx, interval) + t.supplyTokens(ctx) + cancel() + } + } +} + +func (t *Throttle) supplyTokens(ctx context.Context) { + t.mu.Lock() + defer t.mu.Unlock() + tokenCount := t.opsPerSecond + for i := 0; i < t.opsPerSecond; i++ { + select { + case t.tokenChan <- true: + tokenCount-- + case <-ctx.Done(): t.oversupply = tokenCount + return } } + t.oversupply = 0 } func (t *Throttle) String() string {