diff --git a/CHANGELOG.md b/CHANGELOG.md index 8c7357d8..db391755 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) - Fix ISM Transition to omitempty Conditions field ([#609](https://github.com/opensearch-project/opensearch-go/pull/609)) - Fix ISM Allocation field types ([#609](https://github.com/opensearch-project/opensearch-go/pull/609)) - Fix ISM Error Notification types ([#612](https://github.com/opensearch-project/opensearch-go/pull/612)) +- Fix signer receiving drained body on retries ([#620](https://github.com/opensearch-project/opensearch-go/pull/620)) ### Security diff --git a/opensearchtransport/opensearchtransport.go b/opensearchtransport/opensearchtransport.go index 8e492596..82571435 100644 --- a/opensearchtransport/opensearchtransport.go +++ b/opensearchtransport/opensearchtransport.go @@ -292,10 +292,6 @@ func (c *Client) Perform(req *http.Request) (*http.Response, error) { c.setReqURL(conn.URL, req) c.setReqAuth(conn.URL, req) - if err = c.signRequest(req); err != nil { - return nil, fmt.Errorf("failed to sign request: %w", err) - } - if !c.disableRetry && i > 0 && req.Body != nil && req.Body != http.NoBody { body, err := req.GetBody() if err != nil { @@ -304,6 +300,10 @@ func (c *Client) Perform(req *http.Request) (*http.Response, error) { req.Body = body } + if err = c.signRequest(req); err != nil { + return nil, fmt.Errorf("failed to sign request: %w", err) + } + // Set up time measures and execute the request start := time.Now().UTC() res, err = c.transport.RoundTrip(req) diff --git a/opensearchtransport/opensearchtransport_internal_test.go b/opensearchtransport/opensearchtransport_internal_test.go index e04a0ab5..9e325318 100644 --- a/opensearchtransport/opensearchtransport_internal_test.go +++ b/opensearchtransport/opensearchtransport_internal_test.go @@ -64,9 +64,13 @@ type mockSigner struct { SampleKey string SampleValue string ReturnError bool + testHook func(*http.Request) } func (m *mockSigner) SignRequest(req *http.Request) error { + if m.testHook != nil { + m.testHook(req) + } if m.ReturnError { return fmt.Errorf("invalid data") } @@ -732,6 +736,47 @@ func TestTransportPerformRetries(t *testing.T) { } }) + t.Run("Signer can sign correctly during retry", func(t *testing.T) { + u, _ := url.Parse("https://foo.com/bar") + signer := mockSigner{} + callsToSigner := 0 + expectedBody := "FOOBAR" + + signer.testHook = func(req *http.Request) { + callsToSigner++ + body, err := io.ReadAll(req.Body) + if err != nil { + panic(err) + } + if string(body) != expectedBody { + t.Fatalf("request %d body: expected %q, got %q", callsToSigner, expectedBody, body) + } + } + + tp, _ := New( + Config{ + URLs: []*url.URL{u}, + Signer: &signer, + Transport: &mockTransp{ + RoundTripFunc: func(req *http.Request) (*http.Response, error) { + return &http.Response{Status: "MOCK", StatusCode: http.StatusBadGateway}, nil + }, + }, + }, + ) + + req, _ := http.NewRequest(http.MethodPost, "/abc", strings.NewReader(expectedBody)) + //nolint:bodyclose // Mock response does not have a body to close + _, err := tp.Perform(req) + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + + if callsToSigner != 4 { + t.Fatalf("expected 4 requests, got %d", callsToSigner) + } + }) + t.Run("Don't retry request on regular error", func(t *testing.T) { var i int