Skip to content

Commit

Permalink
make Gateway.Open wait until ready event is received
Browse files Browse the repository at this point in the history
  • Loading branch information
topi314 committed Oct 4, 2023
1 parent 44dfd8e commit 6bca87f
Showing 1 changed file with 47 additions and 9 deletions.
56 changes: 47 additions & 9 deletions gateway/gateway_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ func (g *gatewayImpl) open(ctx context.Context) error {
g.config.Logger.Debug(g.formatLogs("opening gateway connection"))

g.connMu.Lock()
defer g.connMu.Unlock()
if g.conn != nil {
g.connMu.Unlock()
return discord.ErrGatewayAlreadyConnected
}
g.status = StatusConnecting
Expand All @@ -120,6 +120,7 @@ func (g *gatewayImpl) open(ctx context.Context) error {
}

g.config.Logger.Error(g.formatLogsf("error connecting to the gateway. url: %s, error: %s, body: %s", gatewayURL, err, body))
g.connMu.Unlock()
return err
}

Expand All @@ -128,13 +129,27 @@ func (g *gatewayImpl) open(ctx context.Context) error {
})

g.conn = conn
g.connMu.Unlock()

// reset rate limiter when connecting
g.config.RateLimiter.Reset()

g.status = StatusWaitingForHello

go g.listen(conn)
readyChan := make(chan error)
go g.listen(conn, readyChan)

select {
case <-ctx.Done():
closeCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
g.Close(closeCtx)
return ctx.Err()
case err = <-readyChan:
if err != nil {
return fmt.Errorf("failed to open gateway connection: %w", err)
}
}

return nil
}
Expand Down Expand Up @@ -226,6 +241,13 @@ func (g *gatewayImpl) reconnectTry(ctx context.Context, try int) error {
}

if err := g.open(ctx); err != nil {
var closeError *websocket.CloseError
if errors.As(err, &closeError) {
closeCode := CloseEventCodeByCode(closeError.Code)
if !closeCode.Reconnect {
return err
}
}
if errors.Is(err, discord.ErrGatewayAlreadyConnected) {
return err
}
Expand Down Expand Up @@ -279,7 +301,7 @@ func (g *gatewayImpl) sendHeartbeat() {
g.lastHeartbeatSent = time.Now().UTC()
}

func (g *gatewayImpl) identify() {
func (g *gatewayImpl) identify() error {
g.status = StatusIdentifying
g.config.Logger.Debug(g.formatLogs("sending Identify command..."))

Expand All @@ -298,12 +320,13 @@ func (g *gatewayImpl) identify() {
}

if err := g.Send(context.TODO(), OpcodeIdentify, identify); err != nil {
g.config.Logger.Error(g.formatLogs("error sending Identify command err: ", err))
return err
}
g.status = StatusWaitingForReady
return nil
}

func (g *gatewayImpl) resume() {
func (g *gatewayImpl) resume() error {
g.status = StatusResuming
resume := MessageDataResume{
Token: g.token,
Expand All @@ -313,16 +336,22 @@ func (g *gatewayImpl) resume() {

g.config.Logger.Debug(g.formatLogs("sending Resume command..."))
if err := g.Send(context.TODO(), OpcodeResume, resume); err != nil {
g.config.Logger.Error(g.formatLogs("error sending resume command err: ", err))
return err
}
return nil
}

func (g *gatewayImpl) listen(conn *websocket.Conn) {
func (g *gatewayImpl) listen(conn *websocket.Conn, readyChan chan<- error) {
defer g.config.Logger.Debug(g.formatLogs("exiting listen goroutine..."))
loop:
for {
mt, data, err := conn.ReadMessage()
if err != nil {
if g.status != StatusReady {
readyChan <- err
close(readyChan)
break loop
}
g.connMu.Lock()
sameConnection := g.conn == conn
g.connMu.Unlock()
Expand Down Expand Up @@ -382,9 +411,14 @@ loop:
go g.heartbeat()

if g.config.LastSequenceReceived == nil || g.config.SessionID == nil {
g.identify()
err = g.identify()
} else {
g.resume()
err = g.resume()
}
if err != nil {
readyChan <- err
close(readyChan)
return
}

case OpcodeDispatch:
Expand Down Expand Up @@ -418,6 +452,10 @@ loop:
})
}
g.eventHandlerFunc(message.T, message.S, g.config.ShardID, eventData)
if _, ok = eventData.(EventReady); ok {
readyChan <- nil
close(readyChan)
}

case OpcodeHeartbeat:
g.sendHeartbeat()
Expand Down

0 comments on commit 6bca87f

Please sign in to comment.