Skip to content

Commit

Permalink
mailgunGH-35: Initial proof of http request update feature
Browse files Browse the repository at this point in the history
  • Loading branch information
mgale committed Feb 7, 2022
1 parent ef54c5c commit 411476a
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 5 deletions.
19 changes: 15 additions & 4 deletions http.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ type HTTPPoolOptions struct {
// receives a request.
// If nil, uses the http.Request.Context()
Context func(*http.Request) context.Context

// UpdateRequest optionally specifies a function that takes in the
// outgoing http request and returns either the same request or a
// new http request.
UpdateRequest func(*http.Request) *http.Request
}

// NewHTTPPool initializes an HTTP pool of peers, and registers itself as a PeerPicker.
Expand Down Expand Up @@ -126,8 +131,9 @@ func (p *HTTPPool) Set(peers ...string) {
p.httpGetters = make(map[string]*httpGetter, len(peers))
for _, peer := range peers {
p.httpGetters[peer] = &httpGetter{
getTransport: p.opts.Transport,
baseURL: peer + p.opts.BasePath,
getTransport: p.opts.Transport,
getUpdateRequest: p.opts.UpdateRequest,
baseURL: peer + p.opts.BasePath,
}
}
}
Expand Down Expand Up @@ -250,8 +256,9 @@ func (p *HTTPPool) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}

type httpGetter struct {
getTransport func(context.Context) http.RoundTripper
baseURL string
getTransport func(context.Context) http.RoundTripper
getUpdateRequest func(*http.Request) *http.Request
baseURL string
}

func (p *httpGetter) GetURL() string {
Expand Down Expand Up @@ -279,6 +286,10 @@ func (h *httpGetter) makeRequest(ctx context.Context, m string, in request, b io
return err
}

if h.getUpdateRequest != nil {
req = h.getUpdateRequest(req)
}

tr := http.DefaultTransport
if h.getTransport != nil {
tr = h.getTransport(ctx)
Expand Down
2 changes: 1 addition & 1 deletion http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ func TestHTTPPool(t *testing.T) {
}

if !bytes.Equal(setValue, getValue.ByteSlice()) {
t.Fatal(errors.New(fmt.Sprintf("incorrect value retrieved after set: %s", getValue)))
t.Fatal(fmt.Errorf("incorrect value retrieved after set: %s", getValue))
}
}

Expand Down

0 comments on commit 411476a

Please sign in to comment.