diff --git a/gateway/gateway_impl.go b/gateway/gateway_impl.go index d6a0d09a..4d1169d6 100644 --- a/gateway/gateway_impl.go +++ b/gateway/gateway_impl.go @@ -92,8 +92,8 @@ func (g *gatewayImpl) open(ctx context.Context) error { g.config.Logger.Debug(g.formatLogs("opening gateway connection")) g.connMu.Lock() - defer g.connMu.Unlock() if g.conn != nil { + g.connMu.Unlock() return discord.ErrGatewayAlreadyConnected } g.status = StatusConnecting @@ -120,6 +120,7 @@ func (g *gatewayImpl) open(ctx context.Context) error { } g.config.Logger.Error(g.formatLogsf("error connecting to the gateway. url: %s, error: %s, body: %s", gatewayURL, err, body)) + g.connMu.Unlock() return err } @@ -128,13 +129,27 @@ func (g *gatewayImpl) open(ctx context.Context) error { }) g.conn = conn + g.connMu.Unlock() // reset rate limiter when connecting g.config.RateLimiter.Reset() g.status = StatusWaitingForHello - go g.listen(conn) + readyChan := make(chan error) + go g.listen(conn, readyChan) + + select { + case <-ctx.Done(): + closeCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + g.Close(closeCtx) + return ctx.Err() + case err = <-readyChan: + if err != nil { + return fmt.Errorf("failed to open gateway connection: %w", err) + } + } return nil } @@ -226,6 +241,13 @@ func (g *gatewayImpl) reconnectTry(ctx context.Context, try int) error { } if err := g.open(ctx); err != nil { + var closeError *websocket.CloseError + if errors.As(err, &closeError) { + closeCode := CloseEventCodeByCode(closeError.Code) + if !closeCode.Reconnect { + return err + } + } if errors.Is(err, discord.ErrGatewayAlreadyConnected) { return err } @@ -279,7 +301,7 @@ func (g *gatewayImpl) sendHeartbeat() { g.lastHeartbeatSent = time.Now().UTC() } -func (g *gatewayImpl) identify() { +func (g *gatewayImpl) identify() error { g.status = StatusIdentifying g.config.Logger.Debug(g.formatLogs("sending Identify command...")) @@ -298,12 +320,13 @@ func (g *gatewayImpl) identify() { } if err := g.Send(context.TODO(), OpcodeIdentify, identify); err != nil { - g.config.Logger.Error(g.formatLogs("error sending Identify command err: ", err)) + return err } g.status = StatusWaitingForReady + return nil } -func (g *gatewayImpl) resume() { +func (g *gatewayImpl) resume() error { g.status = StatusResuming resume := MessageDataResume{ Token: g.token, @@ -313,16 +336,22 @@ func (g *gatewayImpl) resume() { g.config.Logger.Debug(g.formatLogs("sending Resume command...")) if err := g.Send(context.TODO(), OpcodeResume, resume); err != nil { - g.config.Logger.Error(g.formatLogs("error sending resume command err: ", err)) + return err } + return nil } -func (g *gatewayImpl) listen(conn *websocket.Conn) { +func (g *gatewayImpl) listen(conn *websocket.Conn, readyChan chan<- error) { defer g.config.Logger.Debug(g.formatLogs("exiting listen goroutine...")) loop: for { mt, data, err := conn.ReadMessage() if err != nil { + if g.status != StatusReady { + readyChan <- err + close(readyChan) + break loop + } g.connMu.Lock() sameConnection := g.conn == conn g.connMu.Unlock() @@ -382,9 +411,14 @@ loop: go g.heartbeat() if g.config.LastSequenceReceived == nil || g.config.SessionID == nil { - g.identify() + err = g.identify() } else { - g.resume() + err = g.resume() + } + if err != nil { + readyChan <- err + close(readyChan) + return } case OpcodeDispatch: @@ -418,6 +452,10 @@ loop: }) } g.eventHandlerFunc(message.T, message.S, g.config.ShardID, eventData) + if _, ok = eventData.(EventReady); ok { + readyChan <- nil + close(readyChan) + } case OpcodeHeartbeat: g.sendHeartbeat()