diff --git a/appservice/appservice.go b/appservice/appservice.go index 780cd242..be9749bf 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -105,6 +105,8 @@ type AppService struct { botClient *mautrix.Client botIntent *IntentAPI + DefaultHTTPRetries int + clients map[id.UserID]*mautrix.Client clientsLock sync.RWMutex intents map[id.UserID]*IntentAPI @@ -218,6 +220,7 @@ func (as *AppService) makeClient(userID id.UserID) *mautrix.Client { client.AppServiceUserID = userID client.Logger = as.Log.Sub(string(userID)) client.Client = as.HTTPClient + client.DefaultHTTPRetries = as.DefaultHTTPRetries as.clients[userID] = client return client } diff --git a/client.go b/client.go index 20b8638c..2436037d 100644 --- a/client.go +++ b/client.go @@ -27,6 +27,11 @@ type Logger interface { Debugfln(message string, args ...interface{}) } +type WarnLogger interface { + Logger + Warnfln(message string, args ...interface{}) +} + type Stringifiable interface { String() string } @@ -45,6 +50,10 @@ type Client struct { Logger Logger SyncPresence event.Presence + // Number of times that mautrix will retry any HTTP request + // if the request fails entirely or returns a HTTP gateway error (502-504) + DefaultHTTPRetries int + txnID int32 // The ?user_id= query parameter for application services. This must be set *prior* to calling a method. If this is empty, @@ -84,7 +93,7 @@ func DiscoverClientAPI(serverName string) (*ClientWellKnown, error) { } req.Header.Set("Accept", "application/json") - req.Header.Set("User-Agent", DefaultUserAgent + " .well-known fetcher") + req.Header.Set("User-Agent", DefaultUserAgent+" .well-known fetcher") client := &http.Client{Timeout: 30 * time.Second} resp, err := client.Do(req) @@ -255,14 +264,19 @@ func (cli *Client) StopSync() { cli.incrementSyncingID() } -func (cli *Client) LogRequest(req *http.Request, body string) { +const logBodyContextKey = "fi.mau.mautrix.log_body" +const logRequestIDContextKey = "fi.mau.mautrix.request_id" + +func (cli *Client) LogRequest(req *http.Request) { if cli.Logger == nil { return } - if len(body) > 0 { - cli.Logger.Debugfln("%s %s %s", req.Method, req.URL.String(), body) + body, ok := req.Context().Value(logBodyContextKey).(string) + reqID, _ := req.Context().Value(logRequestIDContextKey).(int) + if ok && len(body) > 0 { + cli.Logger.Debugfln("req #%d: %s %s %s", reqID, req.Method, req.URL.String(), body) } else { - cli.Logger.Debugfln("%s %s", req.Method, req.URL.String()) + cli.Logger.Debugfln("req #%d: %s %s", reqID, req.Method, req.URL.String()) } } @@ -271,31 +285,27 @@ func (cli *Client) MakeRequest(method string, httpURL string, reqBody interface{ } type FullRequest struct { - Method string - URL string - Headers http.Header - RequestJSON interface{} - ResponseJSON interface{} - Context context.Context + Method string + URL string + Headers http.Header + RequestJSON interface{} + RequestBody io.Reader + RequestLength int64 + ResponseJSON interface{} + Context context.Context + MaxAttempts int } -// MakeFullRequest makes a JSON HTTP request to the given URL. -// If "resBody" is not nil, the response body will be json.Unmarshalled into it. -// -// Returns the HTTP body as bytes on 2xx with a nil error. Returns an error if the response is not 2xx along -// with the HTTP body bytes if it got that far. This error is an HTTPError which includes the returned -// HTTP status code and possibly a RespError as the WrappedError, if the HTTP body could be decoded as a RespError. -func (cli *Client) MakeFullRequest(params FullRequest) ([]byte, error) { - var req *http.Request - var err error +var requestID int32 + +func (params *FullRequest) compileRequest() (*http.Request, error) { var logBody string - var reqBody io.Reader + reqBody := params.RequestBody if params.Context == nil { params.Context = context.Background() } if params.RequestJSON != nil { - var jsonStr []byte - jsonStr, err = json.Marshal(params.RequestJSON) + jsonStr, err := json.Marshal(params.RequestJSON) if err != nil { return nil, HTTPError{ Message: "failed to marshal JSON", @@ -304,8 +314,13 @@ func (cli *Client) MakeFullRequest(params FullRequest) ([]byte, error) { } logBody = string(jsonStr) reqBody = bytes.NewBuffer(jsonStr) + } else if params.RequestLength > 0 && params.RequestBody != nil { + logBody = fmt.Sprintf("%d bytes", params.RequestLength) } - req, err = http.NewRequestWithContext(params.Context, params.Method, params.URL, reqBody) + ctx := context.WithValue(params.Context, logBodyContextKey, logBody) + reqID := atomic.AddInt32(&requestID, 1) + ctx = context.WithValue(ctx, logRequestIDContextKey, int(reqID)) + req, err := http.NewRequestWithContext(ctx, params.Method, params.URL, reqBody) if err != nil { return nil, HTTPError{ Message: "failed to create request", @@ -315,19 +330,74 @@ func (cli *Client) MakeFullRequest(params FullRequest) ([]byte, error) { if params.Headers != nil { req.Header = params.Headers } - if len(logBody) > 0 { + if params.RequestJSON != nil { req.Header.Set("Content-Type", "application/json") } + if params.RequestLength > 0 && params.RequestBody != nil { + req.ContentLength = params.RequestLength + } + return req, nil +} + +// MakeFullRequest makes a JSON HTTP request to the given URL. +// If "resBody" is not nil, the response body will be json.Unmarshalled into it. +// +// Returns the HTTP body as bytes on 2xx with a nil error. Returns an error if the response is not 2xx along +// with the HTTP body bytes if it got that far. This error is an HTTPError which includes the returned +// HTTP status code and possibly a RespError as the WrappedError, if the HTTP body could be decoded as a RespError. +func (cli *Client) MakeFullRequest(params FullRequest) ([]byte, error) { + if params.MaxAttempts == 0 { + params.MaxAttempts = 1 + cli.DefaultHTTPRetries + } + req, err := params.compileRequest() + if err != nil { + return nil, err + } req.Header.Set("User-Agent", cli.UserAgent) if len(cli.AccessToken) > 0 { req.Header.Set("Authorization", "Bearer "+cli.AccessToken) } - cli.LogRequest(req, logBody) + return cli.executeCompiledRequest(req, params.MaxAttempts-1, 4*time.Second, params.ResponseJSON) +} + +func (cli *Client) logWarning(format string, args ...interface{}) { + warnLogger, ok := cli.Logger.(WarnLogger) + if ok { + warnLogger.Warnfln(format, args...) + } else { + cli.Logger.Debugfln(format, args...) + } +} + +func (cli *Client) doRetry(req *http.Request, cause error, retries int, backoff time.Duration, responseJSON interface{}) ([]byte, error) { + reqID, _ := req.Context().Value(logRequestIDContextKey).(int) + if req.Body != nil { + if req.GetBody == nil { + cli.logWarning("Failed to get new body to retry request #%d: GetBody is nil", reqID) + return nil, cause + } + var err error + req.Body, err = req.GetBody() + if err != nil { + cli.logWarning("Failed to get new body to retry request #%d: %v", reqID, err) + return nil, cause + } + } + cli.logWarning("Request #%d failed: %v, retrying in %d seconds", reqID, cause, int(backoff.Seconds())) + time.Sleep(backoff) + return cli.executeCompiledRequest(req, retries-1, backoff*2, responseJSON) +} + +func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backoff time.Duration, responseJSON interface{}) ([]byte, error) { + cli.LogRequest(req) res, err := cli.Client.Do(req) if res != nil { defer res.Body.Close() } if err != nil { + if retries > 0 { + return cli.doRetry(req, err, retries, backoff, responseJSON) + } return nil, HTTPError{ Request: req, Response: res, @@ -337,6 +407,10 @@ func (cli *Client) MakeFullRequest(params FullRequest) ([]byte, error) { } } + if retries > 0 && (res.StatusCode == http.StatusBadGateway || res.StatusCode == http.StatusServiceUnavailable || res.StatusCode == http.StatusGatewayTimeout) { + return cli.doRetry(req, fmt.Errorf("HTTP %d", res.StatusCode), retries, backoff, responseJSON) + } + contents, err := ioutil.ReadAll(res.Body) if err != nil { return nil, HTTPError{ @@ -361,11 +435,11 @@ func (cli *Client) MakeFullRequest(params FullRequest) ([]byte, error) { } } - if params.ResponseJSON == nil { + if responseJSON == nil { return contents, nil } - if err = json.Unmarshal(contents, ¶ms.ResponseJSON); err != nil { + if err = json.Unmarshal(contents, &responseJSON); err != nil { return nil, HTTPError{ Request: req, Response: res, @@ -411,7 +485,14 @@ func (cli *Client) SyncRequest(timeout int, since, filterID string, fullState bo query["full_state"] = "true" } urlPath := cli.BuildURLWithQuery(URLPath{"sync"}, query) - _, err = cli.MakeFullRequest(FullRequest{Method: "GET", URL: urlPath, ResponseJSON: &resp, Context: ctx}) + _, err = cli.MakeFullRequest(FullRequest{ + Method: http.MethodGet, + URL: urlPath, + ResponseJSON: &resp, + Context: ctx, + // We don't want automatic retries for SyncRequest, the Sync() wrapper handles those. + MaxAttempts: 1, + }) return } @@ -889,7 +970,7 @@ func (cli *Client) UploadBytesWithName(data []byte, contentType, fileName string }) } -// UploadMedia uploads the given data to the content repository and returns an MXC URI. +// Upload uploads the given data to the content repository and returns an MXC URI. // // Deprecated: UploadMedia should be used instead. func (cli *Client) Upload(content io.Reader, contentType string, contentLength int64) (*RespMediaUpload, error) { @@ -917,63 +998,21 @@ func (cli *Client) UploadMedia(data ReqUploadMedia) (*RespMediaUpload, error) { u.RawQuery = q.Encode() } - req, err := http.NewRequest("POST", u.String(), data.Content) - if err != nil { - return nil, HTTPError{ - WrappedError: err, - Message: "failed to create request", - } - } - + var headers http.Header if len(data.ContentType) > 0 { - req.Header.Set("Content-Type", data.ContentType) + headers = http.Header{"Content-Type": []string{data.ContentType}} } - req.Header.Set("Authorization", "Bearer "+cli.AccessToken) - req.Header.Set("User-Agent", cli.UserAgent) - req.ContentLength = data.ContentLength - cli.LogRequest(req, fmt.Sprintf("%d bytes", data.ContentLength)) - - res, err := cli.Client.Do(req) - if res != nil { - defer res.Body.Close() - } - if err != nil { - return nil, err - } - contents, err := ioutil.ReadAll(res.Body) - if err != nil { - return nil, HTTPError{ - Message: "failed to read upload response body", - WrappedError: err, - Response: res, - Request: req, - } - } - if res.StatusCode != 200 { - respErr := &RespError{} - if _ = json.Unmarshal(contents, respErr); respErr.ErrCode == "" { - respErr = nil - } - - return nil, HTTPError{ - Request: req, - Response: res, - RespError: respErr, - } - } var m RespMediaUpload - if err := json.Unmarshal(contents, &m); err != nil { - return nil, HTTPError{ - Request: req, - Response: res, - - Message: "failed to unmarshal upload response body", - ResponseBody: string(contents), - WrappedError: err, - } - } - return &m, nil + _, err := cli.MakeFullRequest(FullRequest{ + Method: http.MethodPost, + URL: u.String(), + Headers: headers, + RequestBody: data.Content, + RequestLength: data.ContentLength, + ResponseJSON: &m, + }) + return &m, err } // JoinedMembers returns a map of joined room members. See https://matrix.org/docs/spec/client_server/r0.4.0.html#get-matrix-client-r0-joined-rooms