Skip to content

Commit

Permalink
Merge pull request #203 from xmidt-org/denopink/feat/ws-count-pings-a…
Browse files Browse the repository at this point in the history
…s-activity

feat: count ping as websocket activity
  • Loading branch information
denopink authored Jul 31, 2024
2 parents 04755df + c3c7f4a commit 31387eb
Showing 1 changed file with 30 additions and 18 deletions.
48 changes: 30 additions & 18 deletions internal/websocket/ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,6 @@ func (ws *Websocket) run(ctx context.Context) {
mode := ws.nextMode(ipv4)

policy := ws.retryPolicyFactory.NewPolicy(ctx)
inactivityTimeout := time.After(ws.inactivityTimeout)

for {
var next time.Duration
Expand Down Expand Up @@ -258,6 +257,7 @@ func (ws *Websocket) run(ctx context.Context) {
// Store the connection so writing can take place.
ws.m.Lock()
ws.conn = conn
activity := make(chan struct{})
ws.conn.SetPingListener((func(ctx context.Context, b []byte) {
if ctx.Err() != nil {
return
Expand All @@ -270,7 +270,9 @@ func (ws *Websocket) run(ctx context.Context) {
})
})

inactivityTimeout = time.After(ws.inactivityTimeout)
if len(activity) == 0 {
activity <- struct{}{}
}
}))
ws.conn.SetPongListener(func(ctx context.Context, b []byte) {
if ctx.Err() != nil {
Expand All @@ -289,22 +291,32 @@ func (ws *Websocket) run(ctx context.Context) {
// Read loop
for {
var msg wrp.Message
ctx, cancel := context.WithTimeout(ctx, ws.inactivityTimeout)
typ, reader, err := ws.conn.Reader(ctx)
if errors.Is(err, context.DeadlineExceeded) {
select {
case <-inactivityTimeout:
// inactivityTimeout occurred, continue with ws.read()'s error handling (connection will be closed).
default:
// Ping was received during ws.conn.Reader(), i.e.: inactivityTimeout was reset.
// Reset inactivityTimeout again for the next ws.conn.Reader().
inactivityTimeout = time.After(ws.inactivityTimeout)
cancel()
continue
ctx, cancel := context.WithCancelCause(ctx)

// Monitor for activity.
go func() {
inactivityTimeout := time.After(ws.inactivityTimeout)
loop1:
for {
select {
case <-ctx.Done():
break loop1
case <-activity:
inactivityTimeout = time.After(ws.inactivityTimeout)
case <-inactivityTimeout:
// inactivityTimeout occurred, cancel the context.
cancel(context.DeadlineExceeded)
break loop1
}
}
} else if errors.Is(err, context.Canceled) {
// Parent context has been canceled.
cancel()
}()

typ, reader, err := ws.conn.Reader(ctx)
ctxErr := context.Cause(ctx)
err = errors.Join(err, ctxErr)
// If ctxErr is context.Canceled then the parent context has been canceled.
if errors.Is(ctxErr, context.Canceled) {
cancel(nil)
break
}

Expand All @@ -318,7 +330,7 @@ func (ws *Websocket) run(ctx context.Context) {
}

// Cancel ws.conn.Reader()'s context after wrp decoding.
cancel()
cancel(nil)
if err != nil {
ws.m.Lock()
ws.conn = nil
Expand Down

0 comments on commit 31387eb

Please sign in to comment.