Skip to content

Commit

Permalink
Fix copying of request body: (#348)
Browse files Browse the repository at this point in the history
  • Loading branch information
mergify[bot] authored Sep 19, 2023
2 parents c402b01 + f6a3401 commit b9b7ef8
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 60 deletions.
2 changes: 2 additions & 0 deletions examples/rpc/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ func testConsumer(ctx context.Context) error {

case rpc.BootDeviceMethod:

case rpc.PingMethod:
rp.Result = "pong"
default:
w.WriteHeader(http.StatusNotFound)
}
Expand Down
29 changes: 3 additions & 26 deletions providers/rpc/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,39 +48,16 @@ func TestRequestKVS(t *testing.T) {
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
kvs := requestKVS(tc.req)
buf := new(bytes.Buffer)
_, _ = io.Copy(buf, tc.req.Body)
kvs := requestKVS(tc.req.Method, tc.req.URL.String(), tc.req.Header, buf)
if diff := cmp.Diff(kvs, tc.expected); diff != "" {
t.Fatalf("requestKVS() mismatch (-want +got):\n%s", diff)
}
})
}
}

func TestRequestKVSOneOffs(t *testing.T) {
t.Run("nil body", func(t *testing.T) {
req, _ := http.NewRequestWithContext(context.Background(), http.MethodPost, "http://example.com", nil)
got := requestKVS(req)
if diff := cmp.Diff(got, []interface{}{"request", requestDetails{}}); diff != "" {
t.Logf("got: %+v", got)
t.Fatalf("requestKVS(req) mismatch (-want +got):\n%s", diff)
}
})
t.Run("nil request", func(t *testing.T) {
if diff := cmp.Diff(requestKVS(nil), []interface{}{"request", requestDetails{}}); diff != "" {
t.Fatalf("requestKVS(nil) mismatch (-want +got):\n%s", diff)
}
})

t.Run("failed to unmarshal body", func(t *testing.T) {
req, _ := http.NewRequestWithContext(context.Background(), http.MethodPost, "http://example.com", bytes.NewBufferString("invalid"))
got := requestKVS(req)
if diff := cmp.Diff(got, []interface{}{"request", requestDetails{URL: "http://example.com", Method: http.MethodPost, Headers: http.Header{}}}); diff != "" {
t.Logf("got: %+v", got)
t.Fatalf("requestKVS(req) mismatch (-want +got):\n%s", diff)
}
})
}

func TestResponseKVS(t *testing.T) {
tests := map[string]struct {
resp *http.Response
Expand Down
18 changes: 7 additions & 11 deletions providers/rpc/logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package rpc
import (
"bytes"
"encoding/json"
"io"
"net/http"
)

Expand All @@ -21,20 +20,17 @@ type responseDetails struct {
}

// requestKVS returns a slice of key, value sets. Used for logging.
func requestKVS(req *http.Request) []interface{} {
func requestKVS(method, url string, headers http.Header, body *bytes.Buffer) []interface{} {
var r requestDetails
if req != nil && req.Body != nil {
if body.Len() > 0 {
var p RequestPayload
reqBody, err := io.ReadAll(req.Body)
if err == nil {
req.Body = io.NopCloser(bytes.NewBuffer(reqBody))
_ = json.Unmarshal(reqBody, &p)
}
_ = json.Unmarshal(body.Bytes(), &p)

r = requestDetails{
Body: p,
Headers: req.Header,
URL: req.URL.String(),
Method: req.Method,
Headers: headers,
URL: url,
Method: method,
}
}

Expand Down
1 change: 1 addition & 0 deletions providers/rpc/payload.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ const (
PowerSetMethod Method = "setPowerState"
PowerGetMethod Method = "getPowerState"
VirtualMediaMethod Method = "setVirtualMedia"
PingMethod Method = "ping"
)

// RequestPayload is the payload sent to the ConsumerURL.
Expand Down
42 changes: 20 additions & 22 deletions providers/rpc/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package rpc
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"hash"
Expand Down Expand Up @@ -182,27 +181,16 @@ func (p *Provider) Open(ctx context.Context) error {
return err
}
p.listenerURL = u
buf := new(bytes.Buffer)
_ = json.NewEncoder(buf).Encode(RequestPayload{})
testReq, err := http.NewRequestWithContext(ctx, p.Opts.Request.HTTPMethod, p.listenerURL.String(), buf)
if err != nil {
return err
}
// test that we can communicate with the rpc consumer.
resp, err := p.Client.Do(testReq)
if err != nil {
return err
}
if resp.StatusCode >= http.StatusInternalServerError {
return fmt.Errorf("issue on the rpc consumer side, status code: %d", resp.StatusCode)
}

// test that the consumer responses with the expected contract (ResponsePayload{}).
if err := json.NewDecoder(resp.Body).Decode(&ResponsePayload{}); err != nil {
return fmt.Errorf("issue with the rpc consumer response: %v", err)
if _, err = p.process(ctx, RequestPayload{
ID: time.Now().UnixNano(),
Host: p.Host,
Method: PingMethod,
}); err != nil {
return err
}

return resp.Body.Close()
return nil
}

// Close a connection to the rpc consumer.
Expand Down Expand Up @@ -274,7 +262,12 @@ func (p *Provider) PowerStateGet(ctx context.Context) (state string, err error)
return "", fmt.Errorf("error from rpc consumer: %v", resp.Error)
}

return resp.Result.(string), nil
s, ok := resp.Result.(string)
if !ok {
return "", fmt.Errorf("expected result equal to type string, got: %T", resp.Result)
}

return s, nil
}

// process is the main function for the roundtrip of rpc calls to the ConsumerURL.
Expand All @@ -292,9 +285,14 @@ func (p *Provider) process(ctx context.Context, rp RequestPayload) (ResponsePayl

// create the signature payload
reqBuf := new(bytes.Buffer)
if _, err := io.Copy(reqBuf, req.Body); err != nil {
reqBody, err := req.GetBody()
if err != nil {
return ResponsePayload{}, fmt.Errorf("failed to get request body: %w", err)
}
if _, err := io.Copy(reqBuf, reqBody); err != nil {
return ResponsePayload{}, fmt.Errorf("failed to read request body: %w", err)
}

headersForSig := http.Header{}
for _, h := range p.Opts.Signature.IncludedPayloadHeaders {
if val := req.Header.Get(h); val != "" {
Expand All @@ -321,7 +319,7 @@ func (p *Provider) process(ctx context.Context, rp RequestPayload) (ResponsePayl
}

// request/response round trip.
kvs := requestKVS(req)
kvs := requestKVS(req.Method, req.URL.String(), req.Header, reqBuf)
kvs = append(kvs, []interface{}{"host", p.Host, "method", rp.Method, "consumerURL", p.ConsumerURL}...)
if rp.Params != nil {
kvs = append(kvs, []interface{}{"params", rp.Params}...)
Expand Down
2 changes: 1 addition & 1 deletion providers/rpc/rpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ func TestPowerStateGet(t *testing.T) {
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
rsp := testConsumer{
rp: ResponsePayload{Result: tc.powerState},
rp: ResponsePayload{ID: 123, Host: "127.0.1.1", Result: tc.powerState},
}
if tc.shouldErr {
rsp.rp.Error = &ResponseError{Code: 500, Message: "failed"}
Expand Down

0 comments on commit b9b7ef8

Please sign in to comment.