From 8cb0c0ee0c519af0e3a5e050ad02bcec1fa65b4d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 26 Apr 2021 23:52:25 +0300 Subject: [PATCH] Return appservice websocket close error from StartWebsocket --- appservice/appservice.go | 2 +- appservice/websocket.go | 128 ++++++++++++++++++++++++++++----------- version.go | 2 +- 3 files changed, 95 insertions(+), 37 deletions(-) diff --git a/appservice/appservice.go b/appservice/appservice.go index be9749bf..e69a65ae 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -113,7 +113,7 @@ type AppService struct { intentsLock sync.RWMutex ws *websocket.Conn - StopWebsocket func() + StopWebsocket func(error) WebsocketCommands chan WebsocketCommand } diff --git a/appservice/websocket.go b/appservice/websocket.go index 24031391..21700fd5 100644 --- a/appservice/websocket.go +++ b/appservice/websocket.go @@ -42,6 +42,64 @@ type WebsocketMessage struct { WebsocketCommand } +type MeowWebsocketCloseCode string + +const ( + MeowServerShuttingDown MeowWebsocketCloseCode = "server_shutting_down" + MeowConnectionReplaced MeowWebsocketCloseCode = "conn_replaced" +) + +var ( + WebsocketManualStop = errors.New("the websocket was disconnected manually") + WebsocketOverridden = errors.New("a new call to StartWebsocket overrode the previous connection") + WebsocketUnknownError = errors.New("an unknown error occurred") +) + +func (mwcc MeowWebsocketCloseCode) String() string { + switch mwcc { + case MeowServerShuttingDown: + return "the server is shutting down" + case MeowConnectionReplaced: + return "the connection was replaced by another client" + default: + return string(mwcc) + } +} + +type CloseCommand struct { + Code int `json:"-"` + Command string `json:"command"` + Status MeowWebsocketCloseCode `json:"status"` +} + +func (cc CloseCommand) Error() string { + return fmt.Sprintf("websocket: close %d: %s", cc.Code, cc.Status.String()) +} + +func parseCloseError(err error) error { + closeError := &websocket.CloseError{} + if !errors.As(err, &closeError) { + return err + } + var closeCommand CloseCommand + closeCommand.Code = closeError.Code + closeCommand.Command = "disconnect" + if len(closeError.Text) > 0 { + jsonErr := json.Unmarshal([]byte(closeError.Text), &closeCommand) + if jsonErr != nil { + return err + } + } + if len(closeCommand.Status) == 0 { + if closeCommand.Code == 4001 { + closeCommand.Status = MeowConnectionReplaced + } else if closeCommand.Code == websocket.CloseServiceRestart { + closeCommand.Status = MeowServerShuttingDown + } + } + return &closeCommand +} + func (as *AppService) SendWebsocket(cmd WebsocketCommand) error { if as.ws == nil { return errors.New("websocket not connected") @@ -49,6 +107,33 @@ func (as *AppService) SendWebsocket(cmd WebsocketCommand) error { return as.ws.WriteJSON(&cmd) } +func (as *AppService) consumeWebsocket(stopFunc func(error), ws *websocket.Conn) { + defer stopFunc(WebsocketUnknownError) + for { + var msg WebsocketMessage + err := ws.ReadJSON(&msg) + if err != nil { + as.Log.Debugln("Error reading from websocket:", err) + stopFunc(parseCloseError(err)) + return + } + if msg.Command == "" || msg.Command == "transaction" { + if as.Registration.EphemeralEvents && msg.EphemeralEvents != nil { + as.handleEvents(msg.EphemeralEvents, event.EphemeralEventType) + } + as.handleEvents(msg.Events, event.UnknownEventType) + } else if msg.Command == "connect" { + as.Log.Debugln("Websocket connect confirmation received") + } else { + select { + case as.WebsocketCommands <- msg.WebsocketCommand: + default: + as.Log.Warnln("Dropping websocket command %s %d / %s", msg.Command, msg.ReqID, msg.Data) + } + } + } +} + func (as *AppService) StartWebsocket(baseURL string, onConnect func()) error { parsed, err := url.Parse(baseURL) if err != nil { @@ -62,7 +147,7 @@ func (as *AppService) StartWebsocket(baseURL string, onConnect func()) error { } ws, resp, err := websocket.DefaultDialer.Dial(parsed.String(), http.Header{ "Authorization": []string{fmt.Sprintf("Bearer %s", as.Registration.AppToken)}, - "User-Agent": []string{as.BotClient().UserAgent}, + "User-Agent": []string{as.BotClient().UserAgent}, }) if resp != nil && resp.StatusCode >= 400 { var errResp ErrorResponse @@ -76,13 +161,13 @@ func (as *AppService) StartWebsocket(baseURL string, onConnect func()) error { return fmt.Errorf("failed to open websocket: %w", err) } if as.StopWebsocket != nil { - as.StopWebsocket() + as.StopWebsocket(WebsocketOverridden) } - closeChan := make(chan struct{}) + closeChan := make(chan error) closeChanSync := sync.Once{} - stopFunc := func() { + stopFunc := func(err error) { closeChanSync.Do(func() { - close(closeChan) + closeChan <- err }) } as.ws = ws @@ -90,40 +175,13 @@ func (as *AppService) StartWebsocket(baseURL string, onConnect func()) error { as.PrepareWebsocket() as.Log.Debugln("Appservice transaction websocket connected") - go func() { - defer stopFunc() - for { - var msg WebsocketMessage - err := ws.ReadJSON(&msg) - if err != nil { - as.Log.Warnln("Error reading from websocket:", err) - return - } - if msg.Command == "" || msg.Command == "transaction" { - if as.Registration.EphemeralEvents && msg.EphemeralEvents != nil { - as.handleEvents(msg.EphemeralEvents, event.EphemeralEventType) - } - as.handleEvents(msg.Events, event.UnknownEventType) - } else if msg.Command == "connect" { - as.Log.Debugln("Websocket connect confirmation received") - } else if msg.Command == "disconnect" { - as.Log.Debugln("Websocket disconnect command received") - break - } else { - select { - case as.WebsocketCommands <- msg.WebsocketCommand: - default: - as.Log.Warnln("Dropping websocket command %s %d / %s", msg.Command, msg.ReqID, msg.Data) - } - } - } - }() + go as.consumeWebsocket(stopFunc, ws) if onConnect != nil { onConnect() } - <-closeChan + closeErr := <-closeChan err = ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseGoingAway, "")) if err != nil && err != websocket.ErrCloseSent { @@ -133,5 +191,5 @@ func (as *AppService) StartWebsocket(baseURL string, onConnect func()) error { if err != nil { as.Log.Warnln("Error closing websocket:", err) } - return nil + return closeErr } diff --git a/version.go b/version.go index 31b425a4..9aca2ef7 100644 --- a/version.go +++ b/version.go @@ -1,5 +1,5 @@ package mautrix -const Version = "v0.9.8" +const Version = "v0.9.9" var DefaultUserAgent = "mautrix-go/" + Version