From 411476abeb8f395a2edb85297f91eb084cf8f176 Mon Sep 17 00:00:00 2001 From: Michael Gale Date: Sun, 6 Feb 2022 19:09:09 -0700 Subject: [PATCH] GH-35: Initial proof of http request update feature --- http.go | 19 +++++++++++++++---- http_test.go | 2 +- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/http.go b/http.go index b1dc10a..89ffa82 100644 --- a/http.go +++ b/http.go @@ -73,6 +73,11 @@ type HTTPPoolOptions struct { // receives a request. // If nil, uses the http.Request.Context() Context func(*http.Request) context.Context + + // UpdateRequest optionally specifies a function that takes in the + // outgoing http request and returns either the same request or a + // new http request. + UpdateRequest func(*http.Request) *http.Request } // NewHTTPPool initializes an HTTP pool of peers, and registers itself as a PeerPicker. @@ -126,8 +131,9 @@ func (p *HTTPPool) Set(peers ...string) { p.httpGetters = make(map[string]*httpGetter, len(peers)) for _, peer := range peers { p.httpGetters[peer] = &httpGetter{ - getTransport: p.opts.Transport, - baseURL: peer + p.opts.BasePath, + getTransport: p.opts.Transport, + getUpdateRequest: p.opts.UpdateRequest, + baseURL: peer + p.opts.BasePath, } } } @@ -250,8 +256,9 @@ func (p *HTTPPool) ServeHTTP(w http.ResponseWriter, r *http.Request) { } type httpGetter struct { - getTransport func(context.Context) http.RoundTripper - baseURL string + getTransport func(context.Context) http.RoundTripper + getUpdateRequest func(*http.Request) *http.Request + baseURL string } func (p *httpGetter) GetURL() string { @@ -279,6 +286,10 @@ func (h *httpGetter) makeRequest(ctx context.Context, m string, in request, b io return err } + if h.getUpdateRequest != nil { + req = h.getUpdateRequest(req) + } + tr := http.DefaultTransport if h.getTransport != nil { tr = h.getTransport(ctx) diff --git a/http_test.go b/http_test.go index 86b6d34..ee18623 100644 --- a/http_test.go +++ b/http_test.go @@ -172,7 +172,7 @@ func TestHTTPPool(t *testing.T) { } if !bytes.Equal(setValue, getValue.ByteSlice()) { - t.Fatal(errors.New(fmt.Sprintf("incorrect value retrieved after set: %s", getValue))) + t.Fatal(fmt.Errorf("incorrect value retrieved after set: %s", getValue)) } }