From 41a1ca6c8d5df3458de79a3941d83a38c0e3d152 Mon Sep 17 00:00:00 2001 From: Nemi Shah Date: Fri, 1 Sep 2023 15:23:34 +0530 Subject: [PATCH] Add retry logic to querier --- .circleci/config.yml | 12 +- .circleci/setupAndTestWithAuthReact.sh | 1 - .gitignore | 4 +- CHANGELOG.md | 4 + addDevTag | 7 - recipe/session/querier_test.go | 285 +++++++++++++++++++++++++ supertokens/constants.go | 4 +- supertokens/querier.go | 54 ++++- test/auth-react-server/main.go | 37 ++++ test/frontendIntegration/main.go | 67 ++++++ 10 files changed, 455 insertions(+), 20 deletions(-) create mode 100644 recipe/session/querier_test.go diff --git a/.circleci/config.yml b/.circleci/config.yml index 876ef3dc..32bce877 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -44,8 +44,16 @@ jobs: steps: - checkout - run: apt-get install lsof - - run: curl -fsSL https://deb.nodesource.com/setup_16.x | bash - - run: apt install -y nodejs + - run: curl https://raw.githubusercontent.com/creationix/nvm/master/install.sh | bash + - run: | + set +e + export NVM_DIR="$HOME/.nvm" + [ -s "$NVM_DIR/nvm.sh" ] && \. "$NVM_DIR/nvm.sh" + [ -s "$NVM_DIR/bash_completion" ] && \. "$NVM_DIR/bash_completion" + nvm install 16 + + echo 'export NVM_DIR="$HOME/.nvm"' >> $BASH_ENV + echo '[ -s "$NVM_DIR/nvm.sh" ] && \. "$NVM_DIR/nvm.sh"' >> $BASH_ENV - run: node --version - run: echo "127.0.0.1 localhost.org" >> /etc/hosts - run: go version diff --git a/.circleci/setupAndTestWithAuthReact.sh b/.circleci/setupAndTestWithAuthReact.sh index a028f45a..ee139d26 100755 --- a/.circleci/setupAndTestWithAuthReact.sh +++ b/.circleci/setupAndTestWithAuthReact.sh @@ -49,7 +49,6 @@ git clone git@github.com:supertokens/supertokens-auth-react.git cd supertokens-auth-react git checkout $2 npm run init -(cd ./examples/for-tests && npm run link) # this is there because in linux machine, postinstall in npm doesn't work.. cd ./test/server/ npm i -d npm i git+https://github.com:supertokens/supertokens-node.git#$3 diff --git a/.gitignore b/.gitignore index 8ac48b7e..302a998b 100644 --- a/.gitignore +++ b/.gitignore @@ -20,4 +20,6 @@ gin/ apiPassword releasePassword -.vscode/ \ No newline at end of file +.vscode/ + +.idea/ \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index 3d383b75..de5dd569 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [unreleased] +## [0.8.4] - 2022-08-31 + +- Adds logic to retry network calls if the core returns status 429 + ## [0.8.3] - 2022-07-30 ### Added - Adds test to verify that session container uses overridden functions diff --git a/addDevTag b/addDevTag index 871e9c94..3e68c4e3 100755 --- a/addDevTag +++ b/addDevTag @@ -1,11 +1,4 @@ #!/bin/bash - -# check if we need to merge master into this branch------------ -if [[ $(git log origin/master ^HEAD) ]]; then - echo "You need to merge master into this branch. Exiting" - exit 1 -fi - # get version------------ version=`cat ./supertokens/constants.go | grep -e 'const VERSION'` while IFS='"' read -ra ADDR; do diff --git a/recipe/session/querier_test.go b/recipe/session/querier_test.go new file mode 100644 index 00000000..e0c74754 --- /dev/null +++ b/recipe/session/querier_test.go @@ -0,0 +1,285 @@ +package session + +import ( + "encoding/json" + "errors" + "github.com/stretchr/testify/assert" + "github.com/supertokens/supertokens-golang/supertokens" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" +) + +func resetQuerier() { + supertokens.SetQuerierApiVersionForTests("") +} + +func TestThatNetworkCallIsRetried(t *testing.T) { + resetAll() + mux := http.NewServeMux() + + numberOfTimesCalled := 0 + numberOfTimesSecondCalled := 0 + numberOfTimesThirdCalled := 0 + + mux.HandleFunc("/testing", func(rw http.ResponseWriter, r *http.Request) { + numberOfTimesCalled++ + rw.WriteHeader(supertokens.RateLimitStatusCode) + rw.Header().Set("Content-Type", "application/json") + response, err := json.Marshal(map[string]interface{}{}) + if err != nil { + t.Error(err.Error()) + } + rw.Write(response) + }) + + mux.HandleFunc("/testing2", func(rw http.ResponseWriter, r *http.Request) { + numberOfTimesSecondCalled++ + rw.Header().Set("Content-Type", "application/json") + + if numberOfTimesSecondCalled == 3 { + rw.WriteHeader(200) + } else { + rw.WriteHeader(supertokens.RateLimitStatusCode) + } + + response, err := json.Marshal(map[string]interface{}{}) + if err != nil { + t.Error(err.Error()) + } + rw.Write(response) + }) + + mux.HandleFunc("/testing3", func(rw http.ResponseWriter, r *http.Request) { + numberOfTimesThirdCalled++ + rw.Header().Set("Content-Type", "application/json") + rw.WriteHeader(200) + response, err := json.Marshal(map[string]interface{}{}) + if err != nil { + t.Error(err.Error()) + } + rw.Write(response) + }) + + testServer := httptest.NewServer(mux) + + defer func() { + testServer.Close() + }() + + config := supertokens.TypeInput{ + Supertokens: &supertokens.ConnectionInfo{ + // We need the querier to call the test server and not the core + ConnectionURI: testServer.URL, + }, + AppInfo: supertokens.AppInfo{ + AppName: "SuperTokens", + WebsiteDomain: "supertokens.io", + APIDomain: "api.supertokens.io", + }, + RecipeList: []supertokens.Recipe{ + Init(nil), + }, + } + + err := supertokens.Init(config) + + if err != nil { + t.Error(err.Error()) + } + + q, err := supertokens.GetNewQuerierInstanceOrThrowError("") + supertokens.SetQuerierApiVersionForTests("3.0") + defer resetQuerier() + + if err != nil { + t.Error(err.Error()) + } + + _, err = q.SendGetRequest("/testing", map[string]string{}) + if err == nil { + t.Error(errors.New("request should have failed but didnt").Error()) + } else { + if !strings.Contains(err.Error(), "with status code: 429") { + t.Error(errors.New("request failed with an unexpected error").Error()) + } + } + + _, err = q.SendGetRequest("/testing2", map[string]string{}) + if err != nil { + t.Error(err.Error()) + } + + _, err = q.SendGetRequest("/testing3", map[string]string{}) + if err != nil { + t.Error(err.Error()) + } + + // One initial call + 5 retries + assert.Equal(t, numberOfTimesCalled, 6) + assert.Equal(t, numberOfTimesSecondCalled, 3) + assert.Equal(t, numberOfTimesThirdCalled, 1) +} + +func TestThatRateLimitErrorsAreThrownBackToTheUser(t *testing.T) { + resetAll() + mux := http.NewServeMux() + + mux.HandleFunc("/testing", func(rw http.ResponseWriter, r *http.Request) { + rw.WriteHeader(supertokens.RateLimitStatusCode) + rw.Header().Set("Content-Type", "application/json") + response, err := json.Marshal(map[string]interface{}{ + "status": "RATE_LIMIT_ERROR", + }) + if err != nil { + t.Error(err.Error()) + } + rw.Write(response) + }) + + testServer := httptest.NewServer(mux) + + defer func() { + testServer.Close() + }() + + config := supertokens.TypeInput{ + Supertokens: &supertokens.ConnectionInfo{ + // We need the querier to call the test server and not the core + ConnectionURI: testServer.URL, + }, + AppInfo: supertokens.AppInfo{ + AppName: "SuperTokens", + WebsiteDomain: "supertokens.io", + APIDomain: "api.supertokens.io", + }, + RecipeList: []supertokens.Recipe{ + Init(nil), + }, + } + + err := supertokens.Init(config) + + if err != nil { + t.Error(err.Error()) + } + + q, err := supertokens.GetNewQuerierInstanceOrThrowError("") + supertokens.SetQuerierApiVersionForTests("3.0") + defer resetQuerier() + + if err != nil { + t.Error(err.Error()) + } + + _, err = q.SendGetRequest("/testing", map[string]string{}) + if err == nil { + t.Error(errors.New("request should have failed but didnt").Error()) + } else { + if !strings.Contains(err.Error(), "with status code: 429") { + t.Error(errors.New("request failed with an unexpected error").Error()) + } + + assert.True(t, strings.Contains(err.Error(), "message: {\"status\":\"RATE_LIMIT_ERROR\"}")) + } +} + +func TestThatParallelCallsHaveIndependentRetryCounters(t *testing.T) { + resetAll() + mux := http.NewServeMux() + + numberOfTimesFirstCalled := 0 + numberOfTimesSecondCalled := 0 + + mux.HandleFunc("/testing", func(rw http.ResponseWriter, r *http.Request) { + if r.URL.Query().Get("id") == "1" { + numberOfTimesFirstCalled++ + } else { + numberOfTimesSecondCalled++ + } + + rw.WriteHeader(supertokens.RateLimitStatusCode) + rw.Header().Set("Content-Type", "application/json") + response, err := json.Marshal(map[string]interface{}{}) + if err != nil { + t.Error(err.Error()) + } + rw.Write(response) + }) + + testServer := httptest.NewServer(mux) + + defer func() { + testServer.Close() + }() + + config := supertokens.TypeInput{ + Supertokens: &supertokens.ConnectionInfo{ + // We need the querier to call the test server and not the core + ConnectionURI: testServer.URL, + }, + AppInfo: supertokens.AppInfo{ + AppName: "SuperTokens", + WebsiteDomain: "supertokens.io", + APIDomain: "api.supertokens.io", + }, + RecipeList: []supertokens.Recipe{ + Init(nil), + }, + } + + err := supertokens.Init(config) + + if err != nil { + t.Error(err.Error()) + } + + q, err := supertokens.GetNewQuerierInstanceOrThrowError("") + supertokens.SetQuerierApiVersionForTests("3.0") + defer resetQuerier() + + if err != nil { + t.Error(err.Error()) + } + + var wg sync.WaitGroup + + wg.Add(2) + + go func() { + _, err = q.SendGetRequest("/testing", map[string]string{ + "id": "1", + }) + if err == nil { + t.Error(errors.New("request should have failed but didnt").Error()) + } else { + if !strings.Contains(err.Error(), "with status code: 429") { + t.Error(errors.New("request failed with an unexpected error").Error()) + } + } + + wg.Done() + }() + + go func() { + _, err = q.SendGetRequest("/testing", map[string]string{ + "id": "2", + }) + if err == nil { + t.Error(errors.New("request should have failed but didnt").Error()) + } else { + if !strings.Contains(err.Error(), "with status code: 429") { + t.Error(errors.New("request failed with an unexpected error").Error()) + } + } + + wg.Done() + }() + + wg.Wait() + + assert.Equal(t, numberOfTimesFirstCalled, 6) + assert.Equal(t, numberOfTimesSecondCalled, 6) +} diff --git a/supertokens/constants.go b/supertokens/constants.go index ba983d2e..37a73a7e 100644 --- a/supertokens/constants.go +++ b/supertokens/constants.go @@ -21,8 +21,10 @@ const ( ) // VERSION current version of the lib -const VERSION = "0.8.3" +const VERSION = "0.8.4" var ( cdiSupported = []string{"2.8", "2.9", "2.10", "2.11", "2.12", "2.13", "2.14", "2.15"} ) + +const RateLimitStatusCode = 429 diff --git a/supertokens/querier.go b/supertokens/querier.go index fd3ef61b..a78cabed 100644 --- a/supertokens/querier.go +++ b/supertokens/querier.go @@ -24,6 +24,7 @@ import ( "net/http" "strings" "sync" + "time" ) type Querier struct { @@ -45,6 +46,10 @@ var ( querierHostLock sync.Mutex ) +func SetQuerierApiVersionForTests(version string) { + querierAPIVersion = version +} + func (q *Querier) GetQuerierAPIVersion() (string, error) { querierLock.Lock() defer querierLock.Unlock() @@ -61,7 +66,7 @@ func (q *Querier) GetQuerierAPIVersion() (string, error) { } client := &http.Client{} return client.Do(req) - }, len(QuerierHosts)) + }, len(QuerierHosts), nil) if err != nil { return "", err @@ -141,7 +146,7 @@ func (q *Querier) SendPostRequest(path string, data map[string]interface{}) (map client := &http.Client{} return client.Do(req) - }, len(QuerierHosts)) + }, len(QuerierHosts), nil) } func (q *Querier) SendDeleteRequest(path string, data map[string]interface{}) (map[string]interface{}, error) { @@ -175,7 +180,7 @@ func (q *Querier) SendDeleteRequest(path string, data map[string]interface{}) (m client := &http.Client{} return client.Do(req) - }, len(QuerierHosts)) + }, len(QuerierHosts), nil) } func (q *Querier) SendGetRequest(path string, params map[string]string) (map[string]interface{}, error) { @@ -210,7 +215,7 @@ func (q *Querier) SendGetRequest(path string, params map[string]string) (map[str client := &http.Client{} return client.Do(req) - }, len(QuerierHosts)) + }, len(QuerierHosts), nil) } func (q *Querier) SendPutRequest(path string, data map[string]interface{}) (map[string]interface{}, error) { @@ -244,12 +249,12 @@ func (q *Querier) SendPutRequest(path string, data map[string]interface{}) (map[ client := &http.Client{} return client.Do(req) - }, len(QuerierHosts)) + }, len(QuerierHosts), nil) } type httpRequestFunction func(url string) (*http.Response, error) -func (q *Querier) sendRequestHelper(path NormalisedURLPath, httpRequest httpRequestFunction, numberOfTries int) (map[string]interface{}, error) { +func (q *Querier) sendRequestHelper(path NormalisedURLPath, httpRequest httpRequestFunction, numberOfTries int, retryInfoMap *map[string]int) (map[string]interface{}, error) { if numberOfTries == 0 { return nil, errors.New("no SuperTokens core available to query") } @@ -257,14 +262,32 @@ func (q *Querier) sendRequestHelper(path NormalisedURLPath, httpRequest httpRequ querierHostLock.Lock() currentDomain := QuerierHosts[querierLastTriedIndex].Domain.GetAsStringDangerous() currentBasePath := QuerierHosts[querierLastTriedIndex].BasePath.GetAsStringDangerous() + + url := currentDomain + currentBasePath + path.GetAsStringDangerous() + + maxRetries := 5 + var _retryInfoMap map[string]int + + if retryInfoMap != nil { + _retryInfoMap = *retryInfoMap + } else { + _retryInfoMap = map[string]int{} + } + + _, ok := _retryInfoMap[url] + + if !ok { + _retryInfoMap[url] = maxRetries + } + querierLastTriedIndex = (querierLastTriedIndex + 1) % len(QuerierHosts) querierHostLock.Unlock() - resp, err := httpRequest(currentDomain + currentBasePath + path.GetAsStringDangerous()) + resp, err := httpRequest(url) if err != nil { if strings.Contains(err.Error(), "connection refused") { - return q.sendRequestHelper(path, httpRequest, numberOfTries-1) + return q.sendRequestHelper(path, httpRequest, numberOfTries-1, &_retryInfoMap) } if resp != nil { resp.Body.Close() @@ -279,6 +302,21 @@ func (q *Querier) sendRequestHelper(path NormalisedURLPath, httpRequest httpRequ return nil, readErr } if resp.StatusCode != 200 { + if resp.StatusCode == RateLimitStatusCode { + retriesLeft := _retryInfoMap[url] + + if retriesLeft > 0 { + _retryInfoMap[url] = retriesLeft - 1 + + attemptsMade := maxRetries - retriesLeft + delay := 10 + (250 * attemptsMade) + + time.Sleep(time.Millisecond * time.Duration(delay)) + + return q.sendRequestHelper(path, httpRequest, numberOfTries, &_retryInfoMap) + } + } + return nil, fmt.Errorf("SuperTokens core threw an error for a request to path: '%s' with status code: %v and message: %s", path.GetAsStringDangerous(), resp.StatusCode, body) } diff --git a/test/auth-react-server/main.go b/test/auth-react-server/main.go index 3acbfad1..765cb6fa 100644 --- a/test/auth-react-server/main.go +++ b/test/auth-react-server/main.go @@ -559,6 +559,43 @@ func callSTInit(passwordlessConfig *plessmodels.TypeInput) { "available": []string{"passwordless", "thirdpartypasswordless", "generalerror"}, }) rw.Write(bytes) + } else if r.URL.Path == "/deleteUser" { + bodyBytes, err := ioutil.ReadAll(r.Body) + if err != nil { + rw.WriteHeader(500) + rw.Write([]byte("Internal error")) + return + } + var body map[string]interface{} + err = json.Unmarshal(bodyBytes, &body) + if err != nil { + rw.WriteHeader(500) + rw.Write([]byte("Internal error")) + return + } + + if body["rid"] != "emailpassword" { + rw.WriteHeader(400) + rw.Write([]byte("{\"message\": \"Not Implemented\"}")) + return + } + + user, err := emailpassword.GetUserByEmail(body["email"].(string)) + if err != nil { + rw.WriteHeader(500) + rw.Write([]byte("Internal error")) + return + } + + err = supertokens.DeleteUser(user.ID) + if err != nil { + rw.WriteHeader(500) + rw.Write([]byte("Internal error")) + return + } + + rw.WriteHeader(200) + rw.Write([]byte("{\"status\": \"OK\"}")) } })) diff --git a/test/frontendIntegration/main.go b/test/frontendIntegration/main.go index b7839637..20bd89a9 100644 --- a/test/frontendIntegration/main.go +++ b/test/frontendIntegration/main.go @@ -17,6 +17,7 @@ package main import ( + "encoding/base64" "encoding/json" "fmt" "io/ioutil" @@ -24,6 +25,7 @@ import ( "os" "strconv" "strings" + "time" "github.com/supertokens/supertokens-golang/recipe/session" "github.com/supertokens/supertokens-golang/recipe/session/sessmodels" @@ -170,6 +172,8 @@ func callSTInit(enableAntiCsrf bool, enableJWT bool, jwtPropertyName string) { setEnableJWT(rw, r) } else if r.URL.Path == "/login" && r.Method == "POST" { login(rw, r) + } else if r.URL.Path == "/login-2.18" && r.Method == "POST" { + login218(rw, r) } else if r.URL.Path == "/beforeeach" && r.Method == "POST" { beforeeach(rw, r) } else if r.URL.Path == "/testUserConfig" && r.Method == "POST" { @@ -398,6 +402,69 @@ func login(response http.ResponseWriter, request *http.Request) { response.Write([]byte(sess.GetUserID())) } +func login218(response http.ResponseWriter, request *http.Request) { + var body map[string]interface{} + _ = json.NewDecoder(request.Body).Decode(&body) + + userID := body["userId"].(string) + payload := body["payload"].(map[string]interface{}) + + querier, err := supertokens.GetNewQuerierInstanceOrThrowError("session") + + if err != nil { + response.WriteHeader(500) + response.Write([]byte("")) + return + } + + supertokens.SetQuerierApiVersionForTests("2.18") + resp, err := querier.SendPostRequest("/recipe/session", map[string]interface{}{ + "userId": userID, + "userDataInJWT": payload, + "userDataInDatabase": map[string]interface{}{}, + "enableAntiCsrf": false, + }) + + if err != nil { + response.WriteHeader(500) + response.Write([]byte("")) + return + } + + supertokens.SetQuerierApiVersionForTests("") + + responseByte, err := json.Marshal(resp) + if err != nil { + response.WriteHeader(500) + response.Write([]byte("")) + return + } + var sessionResp sessmodels.CreateOrRefreshAPIResponse + err = json.Unmarshal(responseByte, &sessionResp) + if err != nil { + response.WriteHeader(500) + response.Write([]byte("")) + return + } + + legacyAccessToken := sessionResp.AccessToken.Token + legacyRefreshToken := sessionResp.RefreshToken.Token + + parsed, _ := json.Marshal(map[string]interface{}{ + "uid": userID, + "ate": uint64(time.Now().UnixNano()/1000000) + 3600000, + "up": payload, + }) + data := []byte(parsed) + + frontToken := base64.StdEncoding.EncodeToString(data) + + response.Header().Set("st-access-token", legacyAccessToken) + response.Header().Set("st-refresh-token", legacyRefreshToken) + response.Header().Set("front-token", frontToken) + response.Write([]byte("")) +} + func fail(w http.ResponseWriter, r *http.Request) { w.WriteHeader(404) w.Write([]byte(""))