diff --git a/CHANGELOG.md b/CHANGELOG.md index c2bcd76e..bfbc2b7f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [unreleased] +## [0.16.1] - 2023-10-03 + +### Changes + +- Added `ValidateAccessToken` to the configuration for social login providers, this function allows you to verify the access token returned by the social provider. If you are using Github as a provider, there is a default implementation provided for this function. + ## [0.16.0] - 2023-09-27 ### Fixes diff --git a/recipe/session/accessTokenVersions_test.go b/recipe/session/accessTokenVersions_test.go index d24970d3..da06b1e1 100644 --- a/recipe/session/accessTokenVersions_test.go +++ b/recipe/session/accessTokenVersions_test.go @@ -1004,11 +1004,12 @@ func TestShouldThrowWhenRefreshInLegacySessionsWithProtectedProp(t *testing.T) { assert.True(t, cookiesAfterRefresh["frontToken"] == "remove") } -/** +/* +* We want to make sure that for access token claims that can be null, the SDK does not fail access token validation if the core does not send them as part of the payload. -For this we verify that validation passes when the keys are nil, empty or a different type +# For this we verify that validation passes when the keys are nil, empty or a different type For now this test checks for: - antiCsrfToken diff --git a/recipe/thirdparty/provider_test.go b/recipe/thirdparty/provider_test.go index 0a8151be..4c9c85be 100644 --- a/recipe/thirdparty/provider_test.go +++ b/recipe/thirdparty/provider_test.go @@ -17,9 +17,13 @@ package thirdparty import ( + "errors" + "io" "io/ioutil" "net/http" + "net/http/httptest" "net/url" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -697,3 +701,81 @@ func TestPassingScopesInConfigForGithub(t *testing.T) { "scope": {"test-scope-1 test-scope-2"}, }, authParams) } + +func TestThatSignInUpFailsIfValidateAccessTokenReturnsError(t *testing.T) { + configValue := supertokens.TypeInput{ + Supertokens: &supertokens.ConnectionInfo{ + ConnectionURI: "http://localhost:8080", + }, + AppInfo: supertokens.AppInfo{ + APIDomain: "api.supertokens.io", + AppName: "SuperTokens", + WebsiteDomain: "supertokens.io", + }, + RecipeList: []supertokens.Recipe{ + Init( + &tpmodels.TypeInput{ + SignInAndUpFeature: tpmodels.TypeInputSignInAndUp{ + Providers: []tpmodels.ProviderInput{ + { + Override: func(originalImplementation *tpmodels.TypeProvider) *tpmodels.TypeProvider { + originalImplementation.ExchangeAuthCodeForOAuthTokens = func(redirectURIInfo tpmodels.TypeRedirectURIInfo, userContext supertokens.UserContext) (tpmodels.TypeOAuthTokens, error) { + return map[string]interface{}{ + "access_token": "wrongaccesstoken", + "id_token": "wrongidtoken", + }, nil + } + + return originalImplementation + }, + Config: tpmodels.ProviderConfig{ + ThirdPartyId: "custom", + Clients: []tpmodels.ProviderClientConfig{ + { + ClientID: "test", + ClientSecret: "test-secret", + Scope: []string{"test-scope-1", "test-scope-2"}, + }, + }, + ValidateAccessToken: func(accessToken string, clientConfig tpmodels.ProviderConfigForClientType, userContext supertokens.UserContext) error { + if accessToken == "wrongaccesstoken" { + return errors.New("Invalid access token") + } + + return nil + }, + }, + }, + }, + }, + }, + ), + }, + } + + BeforeEach() + unittesting.StartUpST("localhost", "8080") + defer AfterEach() + err := supertokens.Init(configValue) + + if err != nil { + t.Error(err.Error()) + } + + mux := http.NewServeMux() + testServer := httptest.NewServer(supertokens.Middleware(mux)) + defer testServer.Close() + + req, err := http.NewRequest(http.MethodPost, testServer.URL+"/auth/signinup", strings.NewReader(`{"thirdPartyId": "custom", "redirectURIInfo": {"redirectURIOnProviderDashboard": "http://127.0.0.1/callback", "redirectURIQueryParams": {"code": "abcdefghj"}}}`)) + if err != nil { + t.Error(err.Error()) + } + + res, err := http.DefaultClient.Do(req) + + data2, err := io.ReadAll(res.Body) + assert.NoError(t, err) + respString := string(data2) + respString = strings.Replace(respString, "\n", "", -1) + assert.Equal(t, respString, "Invalid access token") +} diff --git a/recipe/thirdparty/providers/config_utils.go b/recipe/thirdparty/providers/config_utils.go index b40cf462..845b6e9b 100644 --- a/recipe/thirdparty/providers/config_utils.go +++ b/recipe/thirdparty/providers/config_utils.go @@ -28,6 +28,7 @@ func getProviderConfigForClient(config tpmodels.ProviderConfig, clientConfig tpm OIDCDiscoveryEndpoint: config.OIDCDiscoveryEndpoint, UserInfoMap: config.UserInfoMap, ValidateIdTokenPayload: config.ValidateIdTokenPayload, + ValidateAccessToken: config.ValidateAccessToken, RequireEmail: config.RequireEmail, GenerateFakeEmail: config.GenerateFakeEmail, } diff --git a/recipe/thirdparty/providers/github.go b/recipe/thirdparty/providers/github.go index 01bd827b..58aefc4f 100644 --- a/recipe/thirdparty/providers/github.go +++ b/recipe/thirdparty/providers/github.go @@ -16,6 +16,7 @@ package providers import ( + "encoding/base64" "errors" "fmt" @@ -36,6 +37,42 @@ func Github(input tpmodels.ProviderInput) *tpmodels.TypeProvider { input.Config.TokenEndpoint = "https://github.com/login/oauth/access_token" } + if input.Config.ValidateAccessToken == nil { + input.Config.ValidateAccessToken = func(accessToken string, clientConfig tpmodels.ProviderConfigForClientType, userContext supertokens.UserContext) error { + basicAuthToken := base64.StdEncoding.EncodeToString([]byte(clientConfig.ClientID + ":" + clientConfig.ClientSecret)) + wrongClientIdError := errors.New("Access token does not belong to your application") + + resp, err := doPostRequest("https://api.github.com/applications/"+clientConfig.ClientID+"/token", map[string]interface{}{ + "access_token": accessToken, + }, map[string]interface{}{ + "Authorization": "Basic " + basicAuthToken, + "Content-Type": "application/json", + }) + + if err != nil { + return errors.New("Invalid access token") + } + + app, appOk := resp["app"] + + if !appOk { + return wrongClientIdError + } + + clientId, clientIdOk := app.(map[string]interface{})["client_id"] + + if !clientIdOk { + return wrongClientIdError + } + + if clientId != clientConfig.ClientID { + return wrongClientIdError + } + + return nil + } + } + oOverride := input.Override input.Override = func(originalImplementation *tpmodels.TypeProvider) *tpmodels.TypeProvider { diff --git a/recipe/thirdparty/providers/oauth2_impl.go b/recipe/thirdparty/providers/oauth2_impl.go index aa430b71..2ac9bd72 100644 --- a/recipe/thirdparty/providers/oauth2_impl.go +++ b/recipe/thirdparty/providers/oauth2_impl.go @@ -144,6 +144,13 @@ func oauth2_GetUserInfo(config tpmodels.ProviderConfigForClientType, oAuthTokens } } + if config.ValidateAccessToken != nil && accessTokenOk { + err := config.ValidateAccessToken(accessToken, config, userContext) + if err != nil { + return tpmodels.TypeUserInfo{}, err + } + } + if accessTokenOk && config.UserInfoEndpoint != "" { headers := map[string]string{ "Authorization": "Bearer " + accessToken, diff --git a/recipe/thirdparty/tpmodels/models.go b/recipe/thirdparty/tpmodels/models.go index b9140101..74922394 100644 --- a/recipe/thirdparty/tpmodels/models.go +++ b/recipe/thirdparty/tpmodels/models.go @@ -126,6 +126,7 @@ type ProviderConfig struct { RequireEmail *bool `json:"requireEmail,omitempty"` ValidateIdTokenPayload func(idTokenPayload map[string]interface{}, clientConfig ProviderConfigForClientType, userContext supertokens.UserContext) error `json:"-"` + ValidateAccessToken func(accessToken string, clientConfig ProviderConfigForClientType, userContext supertokens.UserContext) error `json:"-"` GenerateFakeEmail func(thirdPartyUserId string, tenantId string, userContext supertokens.UserContext) string `json:"-"` } @@ -158,6 +159,7 @@ type ProviderConfigForClientType struct { OIDCDiscoveryEndpoint string UserInfoMap TypeUserInfoMap ValidateIdTokenPayload func(idTokenPayload map[string]interface{}, clientConfig ProviderConfigForClientType, userContext supertokens.UserContext) error + ValidateAccessToken func(accessToken string, clientConfig ProviderConfigForClientType, userContext supertokens.UserContext) error RequireEmail *bool GenerateFakeEmail func(thirdPartyUserId string, tenantId string, userContext supertokens.UserContext) string diff --git a/supertokens/constants.go b/supertokens/constants.go index 756eceec..69416d4d 100644 --- a/supertokens/constants.go +++ b/supertokens/constants.go @@ -21,7 +21,7 @@ const ( ) // VERSION current version of the lib -const VERSION = "0.16.0" +const VERSION = "0.16.1" var ( cdiSupported = []string{"3.0"}