From 55d40449fc436fd409925547437547c7aef56847 Mon Sep 17 00:00:00 2001 From: Jacob Weinstock Date: Mon, 11 Sep 2023 14:06:48 -0600 Subject: [PATCH] Check response in Open method: Check status code and response payload in the Open method to validate the endpoint is conformant with the response contract. Update checking of the response error code as well as if it is nil. This will make sure we dont error out when a response contains a value for the error instead of just nil. Signed-off-by: Jacob Weinstock --- providers/rpc/doc.go | 2 +- providers/rpc/http.go | 2 +- providers/rpc/rpc.go | 21 ++++++++++++++++----- providers/rpc/rpc_test.go | 6 +----- 4 files changed, 19 insertions(+), 12 deletions(-) diff --git a/providers/rpc/doc.go b/providers/rpc/doc.go index b56e285e..edf20025 100644 --- a/providers/rpc/doc.go +++ b/providers/rpc/doc.go @@ -3,6 +3,6 @@ Package rpc is a provider that defines an HTTP request/response contract for han It allows users a simple way to interoperate with an existing/bespoke out-of-band management solution. The rpc provider request/response payloads are modeled after JSON-RPC 2.0, but are not JSON-RPC 2.0 -compliant so as to allow for more flexibility. +compliant so as to allow for more flexibility and interoperability with existing systems. */ package rpc diff --git a/providers/rpc/http.go b/providers/rpc/http.go index c74b8642..b51d00a2 100644 --- a/providers/rpc/http.go +++ b/providers/rpc/http.go @@ -45,7 +45,7 @@ func (p *Provider) createRequest(ctx context.Context, rp RequestPayload) (*http. return req, nil } -func (p *Provider) handleResponse(resp *http.Response, reqKeysAndValues []interface{}) (ResponsePayload, error) { +func (p *Provider) handleResponse(resp *http.Response, reqKeysAndValues []any) (ResponsePayload, error) { kvs := reqKeysAndValues defer func() { if !p.LogNotificationsDisabled { diff --git a/providers/rpc/rpc.go b/providers/rpc/rpc.go index 05c799c2..96a09aa0 100644 --- a/providers/rpc/rpc.go +++ b/providers/rpc/rpc.go @@ -3,6 +3,7 @@ package rpc import ( "bytes" "context" + "encoding/json" "errors" "fmt" "hash" @@ -181,16 +182,26 @@ func (p *Provider) Open(ctx context.Context) error { return err } p.listenerURL = u - testReq, err := http.NewRequestWithContext(ctx, p.Opts.Request.HTTPMethod, p.listenerURL.String(), nil) + var buf bytes.Buffer + _ = json.NewEncoder(&buf).Encode(RequestPayload{}) + testReq, err := http.NewRequestWithContext(ctx, p.Opts.Request.HTTPMethod, p.listenerURL.String(), bytes.NewReader(buf.Bytes())) if err != nil { return err } // test that we can communicate with the rpc consumer. - // and that it responses with the spec contract (Response{}). resp, err := p.Client.Do(testReq) if err != nil { return err } + if resp.StatusCode >= http.StatusInternalServerError { + return fmt.Errorf("issue on the rpc consumer side, status code: %d", resp.StatusCode) + } + + // test that the consumer responses with the expected contract (ResponsePayload{}). + var res ResponsePayload + if err := json.NewDecoder(resp.Body).Decode(&res); err != nil { + return fmt.Errorf("issue with the rpc consumer response: %v", err) + } return resp.Body.Close() } @@ -216,7 +227,7 @@ func (p *Provider) BootDeviceSet(ctx context.Context, bootDevice string, setPers if err != nil { return false, err } - if resp.Error != nil { + if resp.Error != nil && resp.Error.Code != 0 { return false, fmt.Errorf("error from rpc consumer: %v", resp.Error) } @@ -239,7 +250,7 @@ func (p *Provider) PowerSet(ctx context.Context, state string) (ok bool, err err if err != nil { return ok, err } - if resp.Error != nil { + if resp.Error != nil && resp.Error.Code != 0 { return ok, fmt.Errorf("error from rpc consumer: %v", resp.Error) } @@ -260,7 +271,7 @@ func (p *Provider) PowerStateGet(ctx context.Context) (state string, err error) if err != nil { return "", err } - if resp.Error != nil { + if resp.Error != nil && resp.Error.Code != 0 { return "", fmt.Errorf("error from rpc consumer: %v", resp.Error) } diff --git a/providers/rpc/rpc_test.go b/providers/rpc/rpc_test.go index 78abf556..5aecf053 100644 --- a/providers/rpc/rpc_test.go +++ b/providers/rpc/rpc_test.go @@ -180,11 +180,7 @@ func TestServerErrors(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() c := New(svr.URL, "127.0.0.1", Secrets{SHA256: {"superSecret1"}}) - if err := c.Open(ctx); err != nil { - t.Fatal(err) - } - _, err := c.PowerStateGet(ctx) - if err == nil { + if err := c.Open(ctx); err == nil { t.Fatal("expected error, got none") } })