From 77872d04f2fa3727349ff329eb933f608ed0a80a Mon Sep 17 00:00:00 2001 From: Owen Cabalceta Date: Tue, 2 Apr 2024 21:18:42 -0400 Subject: [PATCH] chore: update ws.nextMode & tests --- internal/websocket/ws.go | 34 +++++++++++++++++----------------- internal/websocket/ws_test.go | 3 ++- 2 files changed, 19 insertions(+), 18 deletions(-) diff --git a/internal/websocket/ws.go b/internal/websocket/ws.go index 727a33c..482f0de 100644 --- a/internal/websocket/ws.go +++ b/internal/websocket/ws.go @@ -101,6 +101,7 @@ type Websocket struct { wg sync.WaitGroup shutdown context.CancelFunc + mode ipMode policy retry.Policy decoder wrp.Decoder encoder wrp.Encoder @@ -155,11 +156,12 @@ func (ws *Websocket) Start() { var ctx context.Context ctx, ws.shutdown = context.WithCancel(context.Background()) + ws.mode = ipv4 // Init retry policy, but it'll be reset on recurring successful connections. ws.policy = ws.retryPolicyFactory.NewPolicy(ctx) ws.decoder = wrp.NewDecoder(nil, wrp.Msgpack) - go ws.readPump(ctx) + go ws.read(ctx) } // Stop stops the websocket connection. @@ -191,22 +193,21 @@ func (ws *Websocket) Send(ctx context.Context, msg wrp.Message) error { return err } -func (ws *Websocket) readPump(ctx context.Context) { +func (ws *Websocket) read(ctx context.Context) { ws.wg.Add(1) defer ws.wg.Done() - mode := ws.nextMode(ipv4) reconnect := true for { var dialErr error if reconnect { - dialErr = ws.dial(ctx, mode) + dialErr = ws.dial(ctx) } if dialErr == nil { // Read loop for { - msg, err := ws.readMsg(ctx, mode) + msg, err := ws.readMsg(ctx) // If a reconnect was attempted but failed, ErrClosed will be found // in this error list and a reconnect should be attempted again. reconnect = errors.Is(err, ErrClosed) @@ -224,7 +225,6 @@ func (ws *Websocket) readPump(ctx context.Context) { return } - mode = ws.nextMode(mode) next, _ := ws.policy.Next() select { @@ -235,14 +235,14 @@ func (ws *Websocket) readPump(ctx context.Context) { } } -func (ws *Websocket) readMsg(ctx context.Context, mode ipMode) (msg *wrp.Message, err error) { +func (ws *Websocket) readMsg(ctx context.Context) (msg *wrp.Message, err error) { defer func() { if err != nil { // The websocket either failed to read, gave us an unexpected message or a message // that could not be decoded. Attempt to reconnect. // If the reconnect fails, ErrClosed will be added to the error list allowing downstream to attempt a reconnect. // Otherwise, ErrClosed will not be added to the error list. - err = errors.Join(err, ws.dial(ctx, mode)) + err = errors.Join(err, ws.dial(ctx)) } }() @@ -277,8 +277,9 @@ func (ws *Websocket) readMsg(ctx context.Context, mode ipMode) (msg *wrp.Message return } -func (ws *Websocket) dial(ctx context.Context, mode ipMode) (err error) { +func (ws *Websocket) dial(ctx context.Context) (err error) { var ( + mode = ws.nextMode() conn *nhws.Conn resp *http.Response ) @@ -372,6 +373,7 @@ func (rt *custRT) RoundTrip(r *http.Request) (*http.Response, error) { // getRT returns a custom RoundTripper for the WS connection. func (ws *Websocket) getRT(mode ipMode) *custRT { + dialer := &net.Dialer{ Timeout: ws.connectTimeout, KeepAlive: ws.keepAliveInterval, @@ -394,16 +396,14 @@ func (ws *Websocket) getRT(mode ipMode) *custRT { } } -func (ws *Websocket) nextMode(mode ipMode) ipMode { - if mode == ipv4 && ws.withIPv6 { - return ipv6 - } - - if mode == ipv6 && ws.withIPv4 { - return ipv4 +func (ws *Websocket) nextMode() ipMode { + if ws.mode == ipv4 && ws.withIPv6 { + ws.mode = ipv6 + } else if ws.mode == ipv6 && ws.withIPv4 { + ws.mode = ipv4 } - return mode + return ws.mode } func limit(s string) string { diff --git a/internal/websocket/ws_test.go b/internal/websocket/ws_test.go index b1d3024..5c24080 100644 --- a/internal/websocket/ws_test.go +++ b/internal/websocket/ws_test.go @@ -366,9 +366,10 @@ func TestNextMode(t *testing.T) { URL("http://example.com"), ) got, err := New(opts...) + got.mode = tc.mode require.NoError(err) require.NotNil(got) - assert.Equal(tc.expected, got.nextMode(tc.mode)) + assert.Equal(tc.expected, got.nextMode()) }) } }