Skip to content

Commit

Permalink
Return appservice websocket close error from StartWebsocket
Browse files Browse the repository at this point in the history
  • Loading branch information
tulir committed Apr 26, 2021
1 parent e258075 commit 8cb0c0e
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 37 deletions.
2 changes: 1 addition & 1 deletion appservice/appservice.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ type AppService struct {
intentsLock sync.RWMutex

ws *websocket.Conn
StopWebsocket func()
StopWebsocket func(error)
WebsocketCommands chan WebsocketCommand
}

Expand Down
128 changes: 93 additions & 35 deletions appservice/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,98 @@ type WebsocketMessage struct {
WebsocketCommand
}

type MeowWebsocketCloseCode string

const (
MeowServerShuttingDown MeowWebsocketCloseCode = "server_shutting_down"
MeowConnectionReplaced MeowWebsocketCloseCode = "conn_replaced"
)

var (
WebsocketManualStop = errors.New("the websocket was disconnected manually")
WebsocketOverridden = errors.New("a new call to StartWebsocket overrode the previous connection")
WebsocketUnknownError = errors.New("an unknown error occurred")
)

func (mwcc MeowWebsocketCloseCode) String() string {
switch mwcc {
case MeowServerShuttingDown:
return "the server is shutting down"
case MeowConnectionReplaced:
return "the connection was replaced by another client"
default:
return string(mwcc)
}
}

type CloseCommand struct {
Code int `json:"-"`
Command string `json:"command"`
Status MeowWebsocketCloseCode `json:"status"`
}

func (cc CloseCommand) Error() string {
return fmt.Sprintf("websocket: close %d: %s", cc.Code, cc.Status.String())
}

func parseCloseError(err error) error {
closeError := &websocket.CloseError{}
if !errors.As(err, &closeError) {
return err
}
var closeCommand CloseCommand
closeCommand.Code = closeError.Code
closeCommand.Command = "disconnect"
if len(closeError.Text) > 0 {
jsonErr := json.Unmarshal([]byte(closeError.Text), &closeCommand)
if jsonErr != nil {
return err
}
}
if len(closeCommand.Status) == 0 {
if closeCommand.Code == 4001 {
closeCommand.Status = MeowConnectionReplaced
} else if closeCommand.Code == websocket.CloseServiceRestart {
closeCommand.Status = MeowServerShuttingDown
}
}
return &closeCommand
}

func (as *AppService) SendWebsocket(cmd WebsocketCommand) error {
if as.ws == nil {
return errors.New("websocket not connected")
}
return as.ws.WriteJSON(&cmd)
}

func (as *AppService) consumeWebsocket(stopFunc func(error), ws *websocket.Conn) {
defer stopFunc(WebsocketUnknownError)
for {
var msg WebsocketMessage
err := ws.ReadJSON(&msg)
if err != nil {
as.Log.Debugln("Error reading from websocket:", err)
stopFunc(parseCloseError(err))
return
}
if msg.Command == "" || msg.Command == "transaction" {
if as.Registration.EphemeralEvents && msg.EphemeralEvents != nil {
as.handleEvents(msg.EphemeralEvents, event.EphemeralEventType)
}
as.handleEvents(msg.Events, event.UnknownEventType)
} else if msg.Command == "connect" {
as.Log.Debugln("Websocket connect confirmation received")
} else {
select {
case as.WebsocketCommands <- msg.WebsocketCommand:
default:
as.Log.Warnln("Dropping websocket command %s %d / %s", msg.Command, msg.ReqID, msg.Data)
}
}
}
}

func (as *AppService) StartWebsocket(baseURL string, onConnect func()) error {
parsed, err := url.Parse(baseURL)
if err != nil {
Expand All @@ -62,7 +147,7 @@ func (as *AppService) StartWebsocket(baseURL string, onConnect func()) error {
}
ws, resp, err := websocket.DefaultDialer.Dial(parsed.String(), http.Header{
"Authorization": []string{fmt.Sprintf("Bearer %s", as.Registration.AppToken)},
"User-Agent": []string{as.BotClient().UserAgent},
"User-Agent": []string{as.BotClient().UserAgent},
})
if resp != nil && resp.StatusCode >= 400 {
var errResp ErrorResponse
Expand All @@ -76,54 +161,27 @@ func (as *AppService) StartWebsocket(baseURL string, onConnect func()) error {
return fmt.Errorf("failed to open websocket: %w", err)
}
if as.StopWebsocket != nil {
as.StopWebsocket()
as.StopWebsocket(WebsocketOverridden)
}
closeChan := make(chan struct{})
closeChan := make(chan error)
closeChanSync := sync.Once{}
stopFunc := func() {
stopFunc := func(err error) {
closeChanSync.Do(func() {
close(closeChan)
closeChan <- err
})
}
as.ws = ws
as.StopWebsocket = stopFunc
as.PrepareWebsocket()
as.Log.Debugln("Appservice transaction websocket connected")

go func() {
defer stopFunc()
for {
var msg WebsocketMessage
err := ws.ReadJSON(&msg)
if err != nil {
as.Log.Warnln("Error reading from websocket:", err)
return
}
if msg.Command == "" || msg.Command == "transaction" {
if as.Registration.EphemeralEvents && msg.EphemeralEvents != nil {
as.handleEvents(msg.EphemeralEvents, event.EphemeralEventType)
}
as.handleEvents(msg.Events, event.UnknownEventType)
} else if msg.Command == "connect" {
as.Log.Debugln("Websocket connect confirmation received")
} else if msg.Command == "disconnect" {
as.Log.Debugln("Websocket disconnect command received")
break
} else {
select {
case as.WebsocketCommands <- msg.WebsocketCommand:
default:
as.Log.Warnln("Dropping websocket command %s %d / %s", msg.Command, msg.ReqID, msg.Data)
}
}
}
}()
go as.consumeWebsocket(stopFunc, ws)

if onConnect != nil {
onConnect()
}

<-closeChan
closeErr := <-closeChan

err = ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseGoingAway, ""))
if err != nil && err != websocket.ErrCloseSent {
Expand All @@ -133,5 +191,5 @@ func (as *AppService) StartWebsocket(baseURL string, onConnect func()) error {
if err != nil {
as.Log.Warnln("Error closing websocket:", err)
}
return nil
return closeErr
}
2 changes: 1 addition & 1 deletion version.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
package mautrix

const Version = "v0.9.8"
const Version = "v0.9.9"

var DefaultUserAgent = "mautrix-go/" + Version

0 comments on commit 8cb0c0e

Please sign in to comment.