diff --git a/pkg/agent/protocol/http/proxy.go b/pkg/agent/protocol/http/proxy.go index 006fc7c0..1cb5bbbd 100644 --- a/pkg/agent/protocol/http/proxy.go +++ b/pkg/agent/protocol/http/proxy.go @@ -99,77 +99,100 @@ func (h *httpHandler) isExcluded(r *http.Request) bool { return false } -// forward forwards a request to the upstream URL. -// Request is performed immediately, but response won't be sent before the duration specified in delay. -func (h *httpHandler) forward(rw http.ResponseWriter, req *http.Request, delay time.Duration) { - timer := time.After(delay) - +// forward forwards a request to the upstream URL and returns a function that +// copies the response to a ResponseWriter +func (h *httpHandler) forward(req *http.Request) func(rw http.ResponseWriter) { upstreamReq := req.Clone(context.Background()) upstreamReq.Host = h.upstreamURL.Host upstreamReq.URL.Host = h.upstreamURL.Host upstreamReq.URL.Scheme = h.upstreamURL.Scheme upstreamReq.RequestURI = "" // It is an error to set this field in an HTTP client request. + //nolint:bodyclose // it is closed in the returned functions response, err := http.DefaultClient.Do(upstreamReq) - <-timer - if err != nil { - rw.WriteHeader(http.StatusBadGateway) - _, _ = fmt.Fprint(rw, err) - return - } - defer func() { - // Fully consume and then close upstream response body. - _, _ = io.Copy(io.Discard, response.Body) - _ = response.Body.Close() - }() + // return a function that writes the upstream error + if err != nil { + return func(rw http.ResponseWriter) { + rw.WriteHeader(http.StatusBadGateway) + _, _ = fmt.Fprint(rw, err) - // Mirror headers. - for key, values := range response.Header { - for _, value := range values { - rw.Header().Add(key, value) + // Fully consume and then close upstream response body. + _, _ = io.Copy(io.Discard, response.Body) + _ = response.Body.Close() } } - // Mirror status code. - rw.WriteHeader(response.StatusCode) + // return a function that copies upstream response + return func(rw http.ResponseWriter) { + // Mirror headers. + for key, values := range response.Header { + for _, value := range values { + rw.Header().Add(key, value) + } + } + + // Mirror status code. + rw.WriteHeader(response.StatusCode) - // ignore errors writing body, nothing to do. - _, _ = io.Copy(rw, response.Body) + // ignore errors writing body, nothing to do. + _, _ = io.Copy(rw, response.Body) + _ = response.Body.Close() + } } -// injectError waits sleeps the duration specified in delay and then writes the configured error downstream. -func (h *httpHandler) injectError(rw http.ResponseWriter, delay time.Duration) { - time.Sleep(delay) - +// injectError writes the configured error to a ResponseWriter +func (h *httpHandler) injectError(rw http.ResponseWriter) { rw.WriteHeader(int(h.disruption.ErrorCode)) _, _ = rw.Write([]byte(h.disruption.ErrorBody)) } + +func (h *httpHandler) delay() time.Duration { + delay := h.disruption.AverageDelay + if h.disruption.DelayVariation > 0 { + variation := int64(h.disruption.DelayVariation) + delay += time.Duration(variation - 2*rand.Int63n(variation)) + } + + return delay +} + func (h *httpHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { h.metrics.Inc(protocol.MetricRequests) + // if excluded, forward request and return response immediately if h.isExcluded(req) { h.metrics.Inc(protocol.MetricRequestsExcluded) //nolint:contextcheck // Unclear which context the linter requires us to propagate here. - h.forward(rw, req, 0) + h.forward(req)(rw) return } - delay := h.disruption.AverageDelay - if h.disruption.DelayVariation > 0 { - variation := int64(h.disruption.DelayVariation) - delay += time.Duration(variation - 2*rand.Int63n(variation)) - } + // writer is used to write the response + var writer func(rw http.ResponseWriter) + + // forward request + done := make(chan struct{}) + go func() { + if h.disruption.ErrorRate > 0 && rand.Float32() <= h.disruption.ErrorRate { + h.metrics.Inc(protocol.MetricRequestsDisrupted) + writer = h.injectError + } else { + writer = h.forward(req) + } - if h.disruption.ErrorRate > 0 && rand.Float32() <= h.disruption.ErrorRate { - h.metrics.Inc(protocol.MetricRequestsDisrupted) - h.injectError(rw, delay) - return - } + done <- struct{}{} + }() + + // wait for delay + <-time.After(h.delay()) + + // wait for upstream request + <-done - //nolint:contextcheck // Unclear which context the linter requires us to propagate here. - h.forward(rw, req, delay) + // return response + writer(rw) } // Start starts the execution of the proxy