From 051fa433520714f83a067ff257679866d6f9c4be Mon Sep 17 00:00:00 2001 From: Nemi Shah Date: Fri, 6 Oct 2023 14:40:18 +0530 Subject: [PATCH] Check for status in github validate access token --- recipe/thirdparty/providers/github.go | 4 ++-- recipe/thirdparty/providers/oauth2_impl.go | 2 +- recipe/thirdparty/providers/twitter.go | 4 +++- recipe/thirdparty/providers/utils.go | 16 ++++++++-------- 4 files changed, 14 insertions(+), 12 deletions(-) diff --git a/recipe/thirdparty/providers/github.go b/recipe/thirdparty/providers/github.go index 58aefc4f..656cfdcc 100644 --- a/recipe/thirdparty/providers/github.go +++ b/recipe/thirdparty/providers/github.go @@ -42,14 +42,14 @@ func Github(input tpmodels.ProviderInput) *tpmodels.TypeProvider { basicAuthToken := base64.StdEncoding.EncodeToString([]byte(clientConfig.ClientID + ":" + clientConfig.ClientSecret)) wrongClientIdError := errors.New("Access token does not belong to your application") - resp, err := doPostRequest("https://api.github.com/applications/"+clientConfig.ClientID+"/token", map[string]interface{}{ + resp, status, err := doPostRequest("https://api.github.com/applications/"+clientConfig.ClientID+"/token", map[string]interface{}{ "access_token": accessToken, }, map[string]interface{}{ "Authorization": "Basic " + basicAuthToken, "Content-Type": "application/json", }) - if err != nil { + if err != nil || status != 200 { return errors.New("Invalid access token") } diff --git a/recipe/thirdparty/providers/oauth2_impl.go b/recipe/thirdparty/providers/oauth2_impl.go index 2ac9bd72..6caa8aa2 100644 --- a/recipe/thirdparty/providers/oauth2_impl.go +++ b/recipe/thirdparty/providers/oauth2_impl.go @@ -106,7 +106,7 @@ func oauth2_ExchangeAuthCodeForOAuthTokens(config tpmodels.ProviderConfigForClie } /* Transformation needed for dev keys END */ - oAuthTokens, err := doPostRequest(tokenAPIURL, accessTokenAPIParams, nil) + oAuthTokens, _, err := doPostRequest(tokenAPIURL, accessTokenAPIParams, nil) if err != nil { return nil, err } diff --git a/recipe/thirdparty/providers/twitter.go b/recipe/thirdparty/providers/twitter.go index d383f410..b89589ed 100644 --- a/recipe/thirdparty/providers/twitter.go +++ b/recipe/thirdparty/providers/twitter.go @@ -87,9 +87,11 @@ func Twitter(input tpmodels.ProviderInput) *tpmodels.TypeProvider { twitterOauthParams["redirect_uri"] = redirectUri twitterOauthParams["code"] = redirectURIInfo.RedirectURIQueryParams["code"] - return doPostRequest(originalImplementation.Config.TokenEndpoint, twitterOauthParams, map[string]interface{}{ + resp, _, err := doPostRequest(originalImplementation.Config.TokenEndpoint, twitterOauthParams, map[string]interface{}{ "Authorization": "Basic " + basicAuthToken, }) + + return resp, err } if oOverride != nil { diff --git a/recipe/thirdparty/providers/utils.go b/recipe/thirdparty/providers/utils.go index b868d5d5..aa9ac414 100644 --- a/recipe/thirdparty/providers/utils.go +++ b/recipe/thirdparty/providers/utils.go @@ -90,16 +90,16 @@ func doGetRequest(url string, queryParams map[string]interface{}, headers map[st return result, nil } -func doPostRequest(url string, params map[string]interface{}, headers map[string]interface{}) (map[string]interface{}, error) { +func doPostRequest(url string, params map[string]interface{}, headers map[string]interface{}) (map[string]interface{}, int, error) { supertokens.LogDebugMessage(fmt.Sprintf("POST request to %s, with form fields %v and headers %v", url, params, headers)) postBody, err := qs.Marshal(params) if err != nil { - return nil, err + return nil, -1, err } req, err := http.NewRequest("POST", url, bytes.NewBuffer([]byte(postBody))) if err != nil { - return nil, err + return nil, -1, err } for key, value := range headers { req.Header.Set(key, value.(string)) @@ -110,13 +110,13 @@ func doPostRequest(url string, params map[string]interface{}, headers map[string client := &http.Client{} resp, err := client.Do(req) if err != nil { - return nil, err + return nil, resp.StatusCode, err } defer resp.Body.Close() body, err := ioutil.ReadAll(resp.Body) if err != nil { - return nil, err + return nil, resp.StatusCode, err } supertokens.LogDebugMessage(fmt.Sprintf("Received response with status %d and body %s", resp.StatusCode, string(body))) @@ -124,14 +124,14 @@ func doPostRequest(url string, params map[string]interface{}, headers map[string var result map[string]interface{} err = json.Unmarshal(body, &result) if err != nil { - return nil, err + return nil, resp.StatusCode, err } if resp.StatusCode >= 300 { - return nil, fmt.Errorf("POST request to %s resulted in %d status with body %s", url, resp.StatusCode, string(body)) + return nil, resp.StatusCode, fmt.Errorf("POST request to %s resulted in %d status with body %s", url, resp.StatusCode, string(body)) } - return result, nil + return result, resp.StatusCode, nil } // JWKS utils