Skip to content

Commit

Permalink
Add option for automatic HTTP request retries on network error
Browse files Browse the repository at this point in the history
The DefaultHTTPRetries option in Client and AppService sets the number
of retries that *all* HTTP requests do if they encounter a network or
gateway (502-504) error. The default is 0, which means no retries.

It can also be set per-request with FullRequest's MaxAttempts field.
That field starts at 1 instead of 0 so that the default of 0 means
using the default value from the Client.
  • Loading branch information
tulir committed Apr 15, 2021
1 parent b0fac45 commit 06fa15c
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 83 deletions.
3 changes: 3 additions & 0 deletions appservice/appservice.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ type AppService struct {
botClient *mautrix.Client
botIntent *IntentAPI

DefaultHTTPRetries int

clients map[id.UserID]*mautrix.Client
clientsLock sync.RWMutex
intents map[id.UserID]*IntentAPI
Expand Down Expand Up @@ -218,6 +220,7 @@ func (as *AppService) makeClient(userID id.UserID) *mautrix.Client {
client.AppServiceUserID = userID
client.Logger = as.Log.Sub(string(userID))
client.Client = as.HTTPClient
client.DefaultHTTPRetries = as.DefaultHTTPRetries
as.clients[userID] = client
return client
}
Expand Down
205 changes: 122 additions & 83 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ type Logger interface {
Debugfln(message string, args ...interface{})
}

type WarnLogger interface {
Logger
Warnfln(message string, args ...interface{})
}

type Stringifiable interface {
String() string
}
Expand All @@ -45,6 +50,10 @@ type Client struct {
Logger Logger
SyncPresence event.Presence

// Number of times that mautrix will retry any HTTP request
// if the request fails entirely or returns a HTTP gateway error (502-504)
DefaultHTTPRetries int

txnID int32

// The ?user_id= query parameter for application services. This must be set *prior* to calling a method. If this is empty,
Expand Down Expand Up @@ -84,7 +93,7 @@ func DiscoverClientAPI(serverName string) (*ClientWellKnown, error) {
}

req.Header.Set("Accept", "application/json")
req.Header.Set("User-Agent", DefaultUserAgent + " .well-known fetcher")
req.Header.Set("User-Agent", DefaultUserAgent+" .well-known fetcher")

client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
Expand Down Expand Up @@ -255,14 +264,19 @@ func (cli *Client) StopSync() {
cli.incrementSyncingID()
}

func (cli *Client) LogRequest(req *http.Request, body string) {
const logBodyContextKey = "fi.mau.mautrix.log_body"
const logRequestIDContextKey = "fi.mau.mautrix.request_id"

func (cli *Client) LogRequest(req *http.Request) {
if cli.Logger == nil {
return
}
if len(body) > 0 {
cli.Logger.Debugfln("%s %s %s", req.Method, req.URL.String(), body)
body, ok := req.Context().Value(logBodyContextKey).(string)
reqID, _ := req.Context().Value(logRequestIDContextKey).(int)
if ok && len(body) > 0 {
cli.Logger.Debugfln("req #%d: %s %s %s", reqID, req.Method, req.URL.String(), body)
} else {
cli.Logger.Debugfln("%s %s", req.Method, req.URL.String())
cli.Logger.Debugfln("req #%d: %s %s", reqID, req.Method, req.URL.String())
}
}

Expand All @@ -271,31 +285,27 @@ func (cli *Client) MakeRequest(method string, httpURL string, reqBody interface{
}

type FullRequest struct {
Method string
URL string
Headers http.Header
RequestJSON interface{}
ResponseJSON interface{}
Context context.Context
Method string
URL string
Headers http.Header
RequestJSON interface{}
RequestBody io.Reader
RequestLength int64
ResponseJSON interface{}
Context context.Context
MaxAttempts int
}

// MakeFullRequest makes a JSON HTTP request to the given URL.
// If "resBody" is not nil, the response body will be json.Unmarshalled into it.
//
// Returns the HTTP body as bytes on 2xx with a nil error. Returns an error if the response is not 2xx along
// with the HTTP body bytes if it got that far. This error is an HTTPError which includes the returned
// HTTP status code and possibly a RespError as the WrappedError, if the HTTP body could be decoded as a RespError.
func (cli *Client) MakeFullRequest(params FullRequest) ([]byte, error) {
var req *http.Request
var err error
var requestID int32

func (params *FullRequest) compileRequest() (*http.Request, error) {
var logBody string
var reqBody io.Reader
reqBody := params.RequestBody
if params.Context == nil {
params.Context = context.Background()
}
if params.RequestJSON != nil {
var jsonStr []byte
jsonStr, err = json.Marshal(params.RequestJSON)
jsonStr, err := json.Marshal(params.RequestJSON)
if err != nil {
return nil, HTTPError{
Message: "failed to marshal JSON",
Expand All @@ -304,8 +314,13 @@ func (cli *Client) MakeFullRequest(params FullRequest) ([]byte, error) {
}
logBody = string(jsonStr)
reqBody = bytes.NewBuffer(jsonStr)
} else if params.RequestLength > 0 && params.RequestBody != nil {
logBody = fmt.Sprintf("%d bytes", params.RequestLength)
}
req, err = http.NewRequestWithContext(params.Context, params.Method, params.URL, reqBody)
ctx := context.WithValue(params.Context, logBodyContextKey, logBody)
reqID := atomic.AddInt32(&requestID, 1)
ctx = context.WithValue(ctx, logRequestIDContextKey, int(reqID))
req, err := http.NewRequestWithContext(ctx, params.Method, params.URL, reqBody)
if err != nil {
return nil, HTTPError{
Message: "failed to create request",
Expand All @@ -315,19 +330,74 @@ func (cli *Client) MakeFullRequest(params FullRequest) ([]byte, error) {
if params.Headers != nil {
req.Header = params.Headers
}
if len(logBody) > 0 {
if params.RequestJSON != nil {
req.Header.Set("Content-Type", "application/json")
}
if params.RequestLength > 0 && params.RequestBody != nil {
req.ContentLength = params.RequestLength
}
return req, nil
}

// MakeFullRequest makes a JSON HTTP request to the given URL.
// If "resBody" is not nil, the response body will be json.Unmarshalled into it.
//
// Returns the HTTP body as bytes on 2xx with a nil error. Returns an error if the response is not 2xx along
// with the HTTP body bytes if it got that far. This error is an HTTPError which includes the returned
// HTTP status code and possibly a RespError as the WrappedError, if the HTTP body could be decoded as a RespError.
func (cli *Client) MakeFullRequest(params FullRequest) ([]byte, error) {
if params.MaxAttempts == 0 {
params.MaxAttempts = 1 + cli.DefaultHTTPRetries
}
req, err := params.compileRequest()
if err != nil {
return nil, err
}
req.Header.Set("User-Agent", cli.UserAgent)
if len(cli.AccessToken) > 0 {
req.Header.Set("Authorization", "Bearer "+cli.AccessToken)
}
cli.LogRequest(req, logBody)
return cli.executeCompiledRequest(req, params.MaxAttempts-1, 4*time.Second, params.ResponseJSON)
}

func (cli *Client) logWarning(format string, args ...interface{}) {
warnLogger, ok := cli.Logger.(WarnLogger)
if ok {
warnLogger.Warnfln(format, args...)
} else {
cli.Logger.Debugfln(format, args...)
}
}

func (cli *Client) doRetry(req *http.Request, cause error, retries int, backoff time.Duration, responseJSON interface{}) ([]byte, error) {
reqID, _ := req.Context().Value(logRequestIDContextKey).(int)
if req.Body != nil {
if req.GetBody == nil {
cli.logWarning("Failed to get new body to retry request #%d: GetBody is nil", reqID)
return nil, cause
}
var err error
req.Body, err = req.GetBody()
if err != nil {
cli.logWarning("Failed to get new body to retry request #%d: %v", reqID, err)
return nil, cause
}
}
cli.logWarning("Request #%d failed: %v, retrying in %d seconds", reqID, cause, int(backoff.Seconds()))
time.Sleep(backoff)
return cli.executeCompiledRequest(req, retries-1, backoff*2, responseJSON)
}

func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backoff time.Duration, responseJSON interface{}) ([]byte, error) {
cli.LogRequest(req)
res, err := cli.Client.Do(req)
if res != nil {
defer res.Body.Close()
}
if err != nil {
if retries > 0 {
return cli.doRetry(req, err, retries, backoff, responseJSON)
}
return nil, HTTPError{
Request: req,
Response: res,
Expand All @@ -337,6 +407,10 @@ func (cli *Client) MakeFullRequest(params FullRequest) ([]byte, error) {
}
}

if retries > 0 && (res.StatusCode == http.StatusBadGateway || res.StatusCode == http.StatusServiceUnavailable || res.StatusCode == http.StatusGatewayTimeout) {
return cli.doRetry(req, fmt.Errorf("HTTP %d", res.StatusCode), retries, backoff, responseJSON)
}

contents, err := ioutil.ReadAll(res.Body)
if err != nil {
return nil, HTTPError{
Expand All @@ -361,11 +435,11 @@ func (cli *Client) MakeFullRequest(params FullRequest) ([]byte, error) {
}
}

if params.ResponseJSON == nil {
if responseJSON == nil {
return contents, nil
}

if err = json.Unmarshal(contents, &params.ResponseJSON); err != nil {
if err = json.Unmarshal(contents, &responseJSON); err != nil {
return nil, HTTPError{
Request: req,
Response: res,
Expand Down Expand Up @@ -411,7 +485,14 @@ func (cli *Client) SyncRequest(timeout int, since, filterID string, fullState bo
query["full_state"] = "true"
}
urlPath := cli.BuildURLWithQuery(URLPath{"sync"}, query)
_, err = cli.MakeFullRequest(FullRequest{Method: "GET", URL: urlPath, ResponseJSON: &resp, Context: ctx})
_, err = cli.MakeFullRequest(FullRequest{
Method: http.MethodGet,
URL: urlPath,
ResponseJSON: &resp,
Context: ctx,
// We don't want automatic retries for SyncRequest, the Sync() wrapper handles those.
MaxAttempts: 1,
})
return
}

Expand Down Expand Up @@ -889,7 +970,7 @@ func (cli *Client) UploadBytesWithName(data []byte, contentType, fileName string
})
}

// UploadMedia uploads the given data to the content repository and returns an MXC URI.
// Upload uploads the given data to the content repository and returns an MXC URI.
//
// Deprecated: UploadMedia should be used instead.
func (cli *Client) Upload(content io.Reader, contentType string, contentLength int64) (*RespMediaUpload, error) {
Expand Down Expand Up @@ -917,63 +998,21 @@ func (cli *Client) UploadMedia(data ReqUploadMedia) (*RespMediaUpload, error) {
u.RawQuery = q.Encode()
}

req, err := http.NewRequest("POST", u.String(), data.Content)
if err != nil {
return nil, HTTPError{
WrappedError: err,
Message: "failed to create request",
}
}

var headers http.Header
if len(data.ContentType) > 0 {
req.Header.Set("Content-Type", data.ContentType)
headers = http.Header{"Content-Type": []string{data.ContentType}}
}
req.Header.Set("Authorization", "Bearer "+cli.AccessToken)
req.Header.Set("User-Agent", cli.UserAgent)
req.ContentLength = data.ContentLength

cli.LogRequest(req, fmt.Sprintf("%d bytes", data.ContentLength))

res, err := cli.Client.Do(req)
if res != nil {
defer res.Body.Close()
}
if err != nil {
return nil, err
}
contents, err := ioutil.ReadAll(res.Body)
if err != nil {
return nil, HTTPError{
Message: "failed to read upload response body",
WrappedError: err,
Response: res,
Request: req,
}
}
if res.StatusCode != 200 {
respErr := &RespError{}
if _ = json.Unmarshal(contents, respErr); respErr.ErrCode == "" {
respErr = nil
}

return nil, HTTPError{
Request: req,
Response: res,
RespError: respErr,
}
}
var m RespMediaUpload
if err := json.Unmarshal(contents, &m); err != nil {
return nil, HTTPError{
Request: req,
Response: res,

Message: "failed to unmarshal upload response body",
ResponseBody: string(contents),
WrappedError: err,
}
}
return &m, nil
_, err := cli.MakeFullRequest(FullRequest{
Method: http.MethodPost,
URL: u.String(),
Headers: headers,
RequestBody: data.Content,
RequestLength: data.ContentLength,
ResponseJSON: &m,
})
return &m, err
}

// JoinedMembers returns a map of joined room members. See https://matrix.org/docs/spec/client_server/r0.4.0.html#get-matrix-client-r0-joined-rooms
Expand Down

0 comments on commit 06fa15c

Please sign in to comment.