From a36b8facd8504110f182cb1c1d041c7b3fbd3eaf Mon Sep 17 00:00:00 2001 From: Jacob Weinstock Date: Mon, 18 Sep 2023 18:16:03 -0600 Subject: [PATCH] Fix copying of request body: The request body was being dropped after the io.Copy causing issues with further reading of it. Add a PingMethod for the Open call. This way we send something for the client to interpret instead of nothing. Update tests. Signed-off-by: Jacob Weinstock --- providers/rpc/http_test.go | 29 +++----------------------- providers/rpc/logging.go | 18 +++++++--------- providers/rpc/payload.go | 1 + providers/rpc/rpc.go | 42 ++++++++++++++++++-------------------- providers/rpc/rpc_test.go | 2 +- 5 files changed, 32 insertions(+), 60 deletions(-) diff --git a/providers/rpc/http_test.go b/providers/rpc/http_test.go index c211ff62..14dcf801 100644 --- a/providers/rpc/http_test.go +++ b/providers/rpc/http_test.go @@ -48,7 +48,9 @@ func TestRequestKVS(t *testing.T) { } for name, tc := range tests { t.Run(name, func(t *testing.T) { - kvs := requestKVS(tc.req) + buf := new(bytes.Buffer) + _, _ = io.Copy(buf, tc.req.Body) + kvs := requestKVS(tc.req.Method, tc.req.URL.String(), tc.req.Header, buf) if diff := cmp.Diff(kvs, tc.expected); diff != "" { t.Fatalf("requestKVS() mismatch (-want +got):\n%s", diff) } @@ -56,31 +58,6 @@ func TestRequestKVS(t *testing.T) { } } -func TestRequestKVSOneOffs(t *testing.T) { - t.Run("nil body", func(t *testing.T) { - req, _ := http.NewRequestWithContext(context.Background(), http.MethodPost, "http://example.com", nil) - got := requestKVS(req) - if diff := cmp.Diff(got, []interface{}{"request", requestDetails{}}); diff != "" { - t.Logf("got: %+v", got) - t.Fatalf("requestKVS(req) mismatch (-want +got):\n%s", diff) - } - }) - t.Run("nil request", func(t *testing.T) { - if diff := cmp.Diff(requestKVS(nil), []interface{}{"request", requestDetails{}}); diff != "" { - t.Fatalf("requestKVS(nil) mismatch (-want +got):\n%s", diff) - } - }) - - t.Run("failed to unmarshal body", func(t *testing.T) { - req, _ := http.NewRequestWithContext(context.Background(), http.MethodPost, "http://example.com", bytes.NewBufferString("invalid")) - got := requestKVS(req) - if diff := cmp.Diff(got, []interface{}{"request", requestDetails{URL: "http://example.com", Method: http.MethodPost, Headers: http.Header{}}}); diff != "" { - t.Logf("got: %+v", got) - t.Fatalf("requestKVS(req) mismatch (-want +got):\n%s", diff) - } - }) -} - func TestResponseKVS(t *testing.T) { tests := map[string]struct { resp *http.Response diff --git a/providers/rpc/logging.go b/providers/rpc/logging.go index 5e689c68..a4b18757 100644 --- a/providers/rpc/logging.go +++ b/providers/rpc/logging.go @@ -3,7 +3,6 @@ package rpc import ( "bytes" "encoding/json" - "io" "net/http" ) @@ -21,20 +20,17 @@ type responseDetails struct { } // requestKVS returns a slice of key, value sets. Used for logging. -func requestKVS(req *http.Request) []interface{} { +func requestKVS(method, url string, headers http.Header, body *bytes.Buffer) []interface{} { var r requestDetails - if req != nil && req.Body != nil { + if body.Len() > 0 { var p RequestPayload - reqBody, err := io.ReadAll(req.Body) - if err == nil { - req.Body = io.NopCloser(bytes.NewBuffer(reqBody)) - _ = json.Unmarshal(reqBody, &p) - } + _ = json.Unmarshal(body.Bytes(), &p) + r = requestDetails{ Body: p, - Headers: req.Header, - URL: req.URL.String(), - Method: req.Method, + Headers: headers, + URL: url, + Method: method, } } diff --git a/providers/rpc/payload.go b/providers/rpc/payload.go index 560dc2fe..44a046bd 100644 --- a/providers/rpc/payload.go +++ b/providers/rpc/payload.go @@ -9,6 +9,7 @@ const ( PowerSetMethod Method = "setPowerState" PowerGetMethod Method = "getPowerState" VirtualMediaMethod Method = "setVirtualMedia" + PingMethod Method = "ping" ) // RequestPayload is the payload sent to the ConsumerURL. diff --git a/providers/rpc/rpc.go b/providers/rpc/rpc.go index 4651caaa..fd2eaa89 100644 --- a/providers/rpc/rpc.go +++ b/providers/rpc/rpc.go @@ -3,7 +3,6 @@ package rpc import ( "bytes" "context" - "encoding/json" "errors" "fmt" "hash" @@ -182,27 +181,16 @@ func (p *Provider) Open(ctx context.Context) error { return err } p.listenerURL = u - buf := new(bytes.Buffer) - _ = json.NewEncoder(buf).Encode(RequestPayload{}) - testReq, err := http.NewRequestWithContext(ctx, p.Opts.Request.HTTPMethod, p.listenerURL.String(), buf) - if err != nil { - return err - } - // test that we can communicate with the rpc consumer. - resp, err := p.Client.Do(testReq) - if err != nil { - return err - } - if resp.StatusCode >= http.StatusInternalServerError { - return fmt.Errorf("issue on the rpc consumer side, status code: %d", resp.StatusCode) - } - // test that the consumer responses with the expected contract (ResponsePayload{}). - if err := json.NewDecoder(resp.Body).Decode(&ResponsePayload{}); err != nil { - return fmt.Errorf("issue with the rpc consumer response: %v", err) + if _, err = p.process(ctx, RequestPayload{ + ID: time.Now().UnixNano(), + Host: p.Host, + Method: PingMethod, + }); err != nil { + return err } - return resp.Body.Close() + return nil } // Close a connection to the rpc consumer. @@ -274,7 +262,12 @@ func (p *Provider) PowerStateGet(ctx context.Context) (state string, err error) return "", fmt.Errorf("error from rpc consumer: %v", resp.Error) } - return resp.Result.(string), nil + s, ok := resp.Result.(string) + if !ok { + return "", fmt.Errorf("expected result equal to type string, got: %T", resp.Result) + } + + return s, nil } // process is the main function for the roundtrip of rpc calls to the ConsumerURL. @@ -292,9 +285,14 @@ func (p *Provider) process(ctx context.Context, rp RequestPayload) (ResponsePayl // create the signature payload reqBuf := new(bytes.Buffer) - if _, err := io.Copy(reqBuf, req.Body); err != nil { + reqBody, err := req.GetBody() + if err != nil { + return ResponsePayload{}, fmt.Errorf("failed to get request body: %w", err) + } + if _, err := io.Copy(reqBuf, reqBody); err != nil { return ResponsePayload{}, fmt.Errorf("failed to read request body: %w", err) } + headersForSig := http.Header{} for _, h := range p.Opts.Signature.IncludedPayloadHeaders { if val := req.Header.Get(h); val != "" { @@ -321,7 +319,7 @@ func (p *Provider) process(ctx context.Context, rp RequestPayload) (ResponsePayl } // request/response round trip. - kvs := requestKVS(req) + kvs := requestKVS(req.Method, req.URL.String(), req.Header, reqBuf) kvs = append(kvs, []interface{}{"host", p.Host, "method", rp.Method, "consumerURL", p.ConsumerURL}...) if rp.Params != nil { kvs = append(kvs, []interface{}{"params", rp.Params}...) diff --git a/providers/rpc/rpc_test.go b/providers/rpc/rpc_test.go index 5aecf053..61401324 100644 --- a/providers/rpc/rpc_test.go +++ b/providers/rpc/rpc_test.go @@ -131,7 +131,7 @@ func TestPowerStateGet(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { rsp := testConsumer{ - rp: ResponsePayload{Result: tc.powerState}, + rp: ResponsePayload{ID: 123, Host: "127.0.1.1", Result: tc.powerState}, } if tc.shouldErr { rsp.rp.Error = &ResponseError{Code: 500, Message: "failed"}