diff --git a/middleware/throttle.go b/middleware/throttle.go index bdf4f9f1..9a870d88 100644 --- a/middleware/throttle.go +++ b/middleware/throttle.go @@ -22,11 +22,12 @@ type ThrottleOpts struct { Limit int BacklogLimit int BacklogTimeout time.Duration + StatusCode int } // Throttle is a middleware that limits number of currently processed requests // at a time across all users. Note: Throttle is not a rate-limiter per user, -// instead it just puts a ceiling on the number of currently in-flight requests +// instead it just puts a ceiling on the number of current in-flight requests // being processed from the point from where the Throttle middleware is mounted. func Throttle(limit int) func(http.Handler) http.Handler { return ThrottleWithOpts(ThrottleOpts{Limit: limit, BacklogTimeout: defaultBacklogTimeout}) @@ -49,10 +50,16 @@ func ThrottleWithOpts(opts ThrottleOpts) func(http.Handler) http.Handler { panic("chi/middleware: Throttle expects backlogLimit to be positive") } + statusCode := opts.StatusCode + if statusCode == 0 { + statusCode = http.StatusTooManyRequests + } + t := throttler{ tokens: make(chan token, opts.Limit), backlogTokens: make(chan token, opts.Limit+opts.BacklogLimit), backlogTimeout: opts.BacklogTimeout, + statusCode: statusCode, retryAfterFn: opts.RetryAfterFn, } @@ -72,7 +79,7 @@ func ThrottleWithOpts(opts ThrottleOpts) func(http.Handler) http.Handler { case <-ctx.Done(): t.setRetryAfterHeaderIfNeeded(w, true) - http.Error(w, errContextCanceled, http.StatusTooManyRequests) + http.Error(w, errContextCanceled, t.statusCode) return case btok := <-t.backlogTokens: @@ -85,12 +92,12 @@ func ThrottleWithOpts(opts ThrottleOpts) func(http.Handler) http.Handler { select { case <-timer.C: t.setRetryAfterHeaderIfNeeded(w, false) - http.Error(w, errTimedOut, http.StatusTooManyRequests) + http.Error(w, errTimedOut, t.statusCode) return case <-ctx.Done(): timer.Stop() t.setRetryAfterHeaderIfNeeded(w, true) - http.Error(w, errContextCanceled, http.StatusTooManyRequests) + http.Error(w, errContextCanceled, t.statusCode) return case tok := <-t.tokens: defer func() { @@ -103,7 +110,7 @@ func ThrottleWithOpts(opts ThrottleOpts) func(http.Handler) http.Handler { default: t.setRetryAfterHeaderIfNeeded(w, false) - http.Error(w, errCapacityExceeded, http.StatusTooManyRequests) + http.Error(w, errCapacityExceeded, t.statusCode) return } } @@ -119,8 +126,9 @@ type token struct{} type throttler struct { tokens chan token backlogTokens chan token - retryAfterFn func(ctxDone bool) time.Duration backlogTimeout time.Duration + statusCode int + retryAfterFn func(ctxDone bool) time.Duration } // setRetryAfterHeaderIfNeeded sets Retry-After HTTP header if corresponding retryAfterFn option of throttler is initialized. diff --git a/middleware/throttle_test.go b/middleware/throttle_test.go index 8ed7ff18..d4855f45 100644 --- a/middleware/throttle_test.go +++ b/middleware/throttle_test.go @@ -116,7 +116,6 @@ func TestThrottleTriggerGatewayTimeout(t *testing.T) { res, err := client.Get(server.URL) assertNoError(t, err) assertEqual(t, http.StatusOK, res.StatusCode) - }(i) } @@ -136,7 +135,6 @@ func TestThrottleTriggerGatewayTimeout(t *testing.T) { assertNoError(t, err) assertEqual(t, http.StatusTooManyRequests, res.StatusCode) assertEqual(t, errTimedOut, strings.TrimSpace(string(buf))) - }(i) } @@ -175,7 +173,6 @@ func TestThrottleMaximum(t *testing.T) { buf, err := ioutil.ReadAll(res.Body) assertNoError(t, err) assertEqual(t, testContent, buf) - }(i) } @@ -196,7 +193,6 @@ func TestThrottleMaximum(t *testing.T) { assertNoError(t, err) assertEqual(t, http.StatusTooManyRequests, res.StatusCode) assertEqual(t, errCapacityExceeded, strings.TrimSpace(string(buf))) - }(i) } @@ -252,3 +248,54 @@ func TestThrottleMaximum(t *testing.T) { wg.Wait() }*/ + +func TestThrottleCustomStatusCode(t *testing.T) { + const timeout = time.Second * 3 + + wait := make(chan struct{}) + + r := chi.NewRouter() + r.Use(ThrottleWithOpts(ThrottleOpts{Limit: 1, StatusCode: http.StatusServiceUnavailable})) + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + select { + case <-wait: + case <-time.After(timeout): + } + w.WriteHeader(http.StatusOK) + }) + server := httptest.NewServer(r) + defer server.Close() + + const totalRequestCount = 5 + + codes := make(chan int, totalRequestCount) + errs := make(chan error, totalRequestCount) + client := &http.Client{Timeout: timeout} + for i := 0; i < totalRequestCount; i++ { + go func() { + resp, err := client.Get(server.URL) + if err != nil { + errs <- err + return + } + codes <- resp.StatusCode + }() + } + + waitResponse := func(wantCode int) { + select { + case err := <-errs: + t.Fatal(err) + case code := <-codes: + assertEqual(t, wantCode, code) + case <-time.After(timeout): + t.Fatalf("waiting %d code, timeout exceeded", wantCode) + } + } + + for i := 0; i < totalRequestCount-1; i++ { + waitResponse(http.StatusServiceUnavailable) + } + close(wait) // Allow the last request to proceed. + waitResponse(http.StatusOK) +}