diff --git a/client.go b/client.go index 61772b6..9bca258 100644 --- a/client.go +++ b/client.go @@ -109,15 +109,7 @@ func (c *Client) SubscribeWithContext(ctx context.Context, stream string, handle } } } - - // Apply user specified reconnection strategy or default to standard NewExponentialBackOff() reconnection method - var err error - if c.ReconnectStrategy != nil { - err = backoff.RetryNotify(operation, c.ReconnectStrategy, c.ReconnectNotify) - } else { - err = backoff.RetryNotify(operation, backoff.NewExponentialBackOff(), c.ReconnectNotify) - } - return err + return c.retryNotify(ctx, operation) } // SubscribeChan sends all events to the provided channel @@ -183,14 +175,7 @@ func (c *Client) SubscribeChanWithContext(ctx context.Context, stream string, ch go func() { defer c.cleanup(ch) - // Apply user specified reconnection strategy or default to standard NewExponentialBackOff() reconnection method - var err error - if c.ReconnectStrategy != nil { - err = backoff.RetryNotify(operation, c.ReconnectStrategy, c.ReconnectNotify) - } else { - err = backoff.RetryNotify(operation, backoff.NewExponentialBackOff(), c.ReconnectNotify) - } - + err := c.retryNotify(ctx, operation) // channel closed once connected if err != nil && !connected { errch <- err @@ -201,6 +186,18 @@ func (c *Client) SubscribeChanWithContext(ctx context.Context, stream string, ch return err } +func (c *Client) retryNotify(ctx context.Context, operation func() error) error { + var bk backoff.BackOff + // Apply user specified reconnection strategy or default to standard NewExponentialBackOff() reconnection method + if c.ReconnectStrategy != nil { + bk = c.ReconnectStrategy + } else { + bk = backoff.NewExponentialBackOff() + } + bk = backoff.WithContext(bk, ctx) + return backoff.RetryNotify(operation, bk, c.ReconnectNotify) +} + func (c *Client) startReadLoop(reader *EventStreamReader) (chan *Event, chan error) { outCh := make(chan *Event) erChan := make(chan error) diff --git a/client_test.go b/client_test.go index 3468df0..6786219 100644 --- a/client_test.go +++ b/client_test.go @@ -422,3 +422,29 @@ func TestSubscribeWithContextDone(t *testing.T) { assert.Equal(t, n1, n2) } + +func TestSubscribeWithContextAbortRetrier(t *testing.T) { + // Run a server that only responds with HTTP errors which will put the client into the + // backoff.RetryNotify loop. + const status = http.StatusBadGateway + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Log(r.Method, r.URL.String(), http.StatusText(status)) + w.WriteHeader(status) + })) + defer srv.Close() + + ctx, cancel := context.WithCancel(context.Background()) + c := NewClient(srv.URL) + c.ReconnectNotify = backoff.Notify(func(err error, d time.Duration) { + t.Logf("ReconnectNotify err: %v, duration: %s", err, d.String()) + // The client has processed the HTTP server error from above, so cancel the context + // for the SubscribeWithContext call. + cancel() + }) + + err := c.SubscribeWithContext(ctx, "test", func(msg *Event) { + t.Fatal("Received event when none was expected:", msg) + }) + require.Error(t, err) + assert.Regexp(t, `could not connect to stream: `+http.StatusText(status), err.Error()) +}