diff --git a/pkg/agent/protocol/http/proxy.go b/pkg/agent/protocol/http/proxy.go index 29274c58..b6efd080 100644 --- a/pkg/agent/protocol/http/proxy.go +++ b/pkg/agent/protocol/http/proxy.go @@ -74,16 +74,6 @@ func NewProxy(c ProxyConfig, d Disruption) (protocol.Proxy, error) { }, nil } -// contains verifies if a list of strings contains the given string -func contains(list []string, target string) bool { - for _, element := range list { - if element == target { - return true - } - } - return false -} - // httpClient defines the method for executing HTTP requests. It is used to allow mocking // the client in tests type httpClient interface { @@ -97,56 +87,84 @@ type httpHandler struct { client httpClient } -func (h *httpHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { - var statusCode int - headers := http.Header{} - body := io.NopCloser(strings.NewReader(h.disruption.ErrorBody)) - - excluded := contains(h.disruption.Excluded, req.URL.Path) - - if !excluded && h.disruption.ErrorRate > 0 && rand.Float32() <= h.disruption.ErrorRate { - // force error code - statusCode = int(h.disruption.ErrorCode) - } else { - req.Host = h.upstreamURL.Host - req.URL.Host = h.upstreamURL.Host - req.URL.Scheme = h.upstreamURL.Scheme - req.RequestURI = "" - originServerResponse, srvErr := h.client.Do(req) - if srvErr != nil { - rw.WriteHeader(http.StatusInternalServerError) - _, _ = fmt.Fprint(rw, srvErr) - return +// isExcluded checks whether a request should be proxied through without any kind of modification whatsoever. +func (h *httpHandler) isExcluded(r *http.Request) bool { + for _, excluded := range h.disruption.Excluded { + if strings.EqualFold(r.URL.Path, excluded) { + return true } + } - headers = originServerResponse.Header - statusCode = originServerResponse.StatusCode - body = originServerResponse.Body + return false +} - defer func() { - _ = originServerResponse.Body.Close() - }() - } +// forward forwards a request to the upstream URL. +// Request is performed immediately, but response won't be sent before the duration specified in delay. +func (h *httpHandler) forward(rw http.ResponseWriter, req *http.Request, delay time.Duration) { + timer := time.After(delay) - if !excluded && h.disruption.AverageDelay > 0 { - delay := int64(h.disruption.AverageDelay) - if h.disruption.DelayVariation > 0 { - variation := int64(h.disruption.DelayVariation) - delay = delay + variation - 2*rand.Int63n(variation) - } - time.Sleep(time.Duration(delay)) + upstreamReq := req.Clone(context.Background()) + upstreamReq.Host = h.upstreamURL.Host + upstreamReq.URL.Host = h.upstreamURL.Host + upstreamReq.URL.Scheme = h.upstreamURL.Scheme + upstreamReq.RequestURI = "" // It is an error to set this field in an HTTP client request. + + response, err := h.client.Do(req) + <-timer + if err != nil { + rw.WriteHeader(http.StatusBadGateway) + _, _ = fmt.Fprint(rw, err) + return } - // return response to the client - for key, values := range headers { + defer func() { + // Fully consume and then close upstream response body. + _, _ = io.Copy(io.Discard, response.Body) + _ = response.Body.Close() + }() + + // Mirror headers. + for key, values := range response.Header { for _, value := range values { rw.Header().Add(key, value) } } - rw.WriteHeader(statusCode) + + // Mirror status code. + rw.WriteHeader(response.StatusCode) // ignore errors writing body, nothing to do. - _, _ = io.Copy(rw, body) + _, _ = io.Copy(rw, response.Body) +} + +// injectError waits sleeps the duration specified in delay and then writes the configured error downstream. +func (h *httpHandler) injectError(rw http.ResponseWriter, delay time.Duration) { + time.Sleep(delay) + + rw.WriteHeader(int(h.disruption.ErrorCode)) + _, _ = rw.Write([]byte(h.disruption.ErrorBody)) +} + +func (h *httpHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + if h.isExcluded(req) { + //nolint:contextcheck // Unclear which context the linter requires us to propagate here. + h.forward(rw, req, 0) + return + } + + delay := h.disruption.AverageDelay + if h.disruption.DelayVariation > 0 { + variation := int64(h.disruption.DelayVariation) + delay += time.Duration(variation - 2*rand.Int63n(variation)) + } + + if h.disruption.ErrorRate > 0 && rand.Float32() <= h.disruption.ErrorRate { + h.injectError(rw, delay) + return + } + + //nolint:contextcheck // Unclear which context the linter requires us to propagate here. + h.forward(rw, req, delay) } // Start starts the execution of the proxy