Skip to content

Commit

Permalink
chore: update ws.nextMode & tests
Browse files Browse the repository at this point in the history
  • Loading branch information
denopink committed Apr 3, 2024
1 parent d5f5710 commit 77872d0
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 18 deletions.
34 changes: 17 additions & 17 deletions internal/websocket/ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ type Websocket struct {
wg sync.WaitGroup
shutdown context.CancelFunc

mode ipMode
policy retry.Policy
decoder wrp.Decoder
encoder wrp.Encoder
Expand Down Expand Up @@ -155,11 +156,12 @@ func (ws *Websocket) Start() {

var ctx context.Context
ctx, ws.shutdown = context.WithCancel(context.Background())
ws.mode = ipv4
// Init retry policy, but it'll be reset on recurring successful connections.
ws.policy = ws.retryPolicyFactory.NewPolicy(ctx)
ws.decoder = wrp.NewDecoder(nil, wrp.Msgpack)

go ws.readPump(ctx)
go ws.read(ctx)
}

// Stop stops the websocket connection.
Expand Down Expand Up @@ -191,22 +193,21 @@ func (ws *Websocket) Send(ctx context.Context, msg wrp.Message) error {
return err
}

func (ws *Websocket) readPump(ctx context.Context) {
func (ws *Websocket) read(ctx context.Context) {
ws.wg.Add(1)
defer ws.wg.Done()

mode := ws.nextMode(ipv4)
reconnect := true
for {
var dialErr error
if reconnect {
dialErr = ws.dial(ctx, mode)
dialErr = ws.dial(ctx)
}

if dialErr == nil {
// Read loop
for {
msg, err := ws.readMsg(ctx, mode)
msg, err := ws.readMsg(ctx)
// If a reconnect was attempted but failed, ErrClosed will be found
// in this error list and a reconnect should be attempted again.
reconnect = errors.Is(err, ErrClosed)
Expand All @@ -224,7 +225,6 @@ func (ws *Websocket) readPump(ctx context.Context) {
return
}

mode = ws.nextMode(mode)
next, _ := ws.policy.Next()

select {
Expand All @@ -235,14 +235,14 @@ func (ws *Websocket) readPump(ctx context.Context) {
}
}

func (ws *Websocket) readMsg(ctx context.Context, mode ipMode) (msg *wrp.Message, err error) {
func (ws *Websocket) readMsg(ctx context.Context) (msg *wrp.Message, err error) {
defer func() {
if err != nil {
// The websocket either failed to read, gave us an unexpected message or a message
// that could not be decoded. Attempt to reconnect.
// If the reconnect fails, ErrClosed will be added to the error list allowing downstream to attempt a reconnect.
// Otherwise, ErrClosed will not be added to the error list.
err = errors.Join(err, ws.dial(ctx, mode))
err = errors.Join(err, ws.dial(ctx))
}
}()

Expand Down Expand Up @@ -277,8 +277,9 @@ func (ws *Websocket) readMsg(ctx context.Context, mode ipMode) (msg *wrp.Message
return
}

func (ws *Websocket) dial(ctx context.Context, mode ipMode) (err error) {
func (ws *Websocket) dial(ctx context.Context) (err error) {
var (
mode = ws.nextMode()
conn *nhws.Conn
resp *http.Response
)
Expand Down Expand Up @@ -372,6 +373,7 @@ func (rt *custRT) RoundTrip(r *http.Request) (*http.Response, error) {

// getRT returns a custom RoundTripper for the WS connection.
func (ws *Websocket) getRT(mode ipMode) *custRT {

dialer := &net.Dialer{
Timeout: ws.connectTimeout,
KeepAlive: ws.keepAliveInterval,
Expand All @@ -394,16 +396,14 @@ func (ws *Websocket) getRT(mode ipMode) *custRT {
}
}

func (ws *Websocket) nextMode(mode ipMode) ipMode {
if mode == ipv4 && ws.withIPv6 {
return ipv6
}

if mode == ipv6 && ws.withIPv4 {
return ipv4
func (ws *Websocket) nextMode() ipMode {
if ws.mode == ipv4 && ws.withIPv6 {
ws.mode = ipv6
} else if ws.mode == ipv6 && ws.withIPv4 {
ws.mode = ipv4
}

return mode
return ws.mode
}

func limit(s string) string {
Expand Down
3 changes: 2 additions & 1 deletion internal/websocket/ws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,9 +366,10 @@ func TestNextMode(t *testing.T) {
URL("http://example.com"),
)
got, err := New(opts...)
got.mode = tc.mode
require.NoError(err)
require.NotNil(got)
assert.Equal(tc.expected, got.nextMode(tc.mode))
assert.Equal(tc.expected, got.nextMode())
})
}
}
Expand Down

0 comments on commit 77872d0

Please sign in to comment.