Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add ValidateAccessToken function to providers #376

Merged
merged 4 commits into from
Oct 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [unreleased]

## [0.16.1] - 2023-10-03

### Changes

- Added `ValidateAccessToken` to the configuration for social login providers, this function allows you to verify the access token returned by the social provider. If you are using Github as a provider, there is a default implementation provided for this function.

## [0.16.0] - 2023-09-27

### Fixes
Expand Down
5 changes: 3 additions & 2 deletions recipe/session/accessTokenVersions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1004,11 +1004,12 @@ func TestShouldThrowWhenRefreshInLegacySessionsWithProtectedProp(t *testing.T) {
assert.True(t, cookiesAfterRefresh["frontToken"] == "remove")
}

/**
/*
*
We want to make sure that for access token claims that can be null, the SDK does not fail access token validation if the
core does not send them as part of the payload.

For this we verify that validation passes when the keys are nil, empty or a different type
# For this we verify that validation passes when the keys are nil, empty or a different type

For now this test checks for:
- antiCsrfToken
Expand Down
82 changes: 82 additions & 0 deletions recipe/thirdparty/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,13 @@
package thirdparty

import (
"errors"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -697,3 +701,81 @@ func TestPassingScopesInConfigForGithub(t *testing.T) {
"scope": {"test-scope-1 test-scope-2"},
}, authParams)
}

func TestThatSignInUpFailsIfValidateAccessTokenReturnsError(t *testing.T) {
configValue := supertokens.TypeInput{
Supertokens: &supertokens.ConnectionInfo{
ConnectionURI: "http://localhost:8080",
},
AppInfo: supertokens.AppInfo{
APIDomain: "api.supertokens.io",
AppName: "SuperTokens",
WebsiteDomain: "supertokens.io",
},
RecipeList: []supertokens.Recipe{
Init(
&tpmodels.TypeInput{
SignInAndUpFeature: tpmodels.TypeInputSignInAndUp{
Providers: []tpmodels.ProviderInput{
{
Override: func(originalImplementation *tpmodels.TypeProvider) *tpmodels.TypeProvider {
originalImplementation.ExchangeAuthCodeForOAuthTokens = func(redirectURIInfo tpmodels.TypeRedirectURIInfo, userContext supertokens.UserContext) (tpmodels.TypeOAuthTokens, error) {
return map[string]interface{}{
"access_token": "wrongaccesstoken",
"id_token": "wrongidtoken",
}, nil
}

return originalImplementation
},
Config: tpmodels.ProviderConfig{
ThirdPartyId: "custom",
Clients: []tpmodels.ProviderClientConfig{
{
ClientID: "test",
ClientSecret: "test-secret",
Scope: []string{"test-scope-1", "test-scope-2"},
},
},
ValidateAccessToken: func(accessToken string, clientConfig tpmodels.ProviderConfigForClientType, userContext supertokens.UserContext) error {
if accessToken == "wrongaccesstoken" {
return errors.New("Invalid access token")
}

return nil
},
},
},
},
},
},
),
},
}

BeforeEach()
unittesting.StartUpST("localhost", "8080")
defer AfterEach()
err := supertokens.Init(configValue)

if err != nil {
t.Error(err.Error())
}

mux := http.NewServeMux()
testServer := httptest.NewServer(supertokens.Middleware(mux))
defer testServer.Close()

req, err := http.NewRequest(http.MethodPost, testServer.URL+"/auth/signinup", strings.NewReader(`{"thirdPartyId": "custom", "redirectURIInfo": {"redirectURIOnProviderDashboard": "http://127.0.0.1/callback", "redirectURIQueryParams": {"code": "abcdefghj"}}}`))
if err != nil {
t.Error(err.Error())
}

res, err := http.DefaultClient.Do(req)

data2, err := io.ReadAll(res.Body)
assert.NoError(t, err)
respString := string(data2)
respString = strings.Replace(respString, "\n", "", -1)
assert.Equal(t, respString, "Invalid access token")
}
1 change: 1 addition & 0 deletions recipe/thirdparty/providers/config_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ func getProviderConfigForClient(config tpmodels.ProviderConfig, clientConfig tpm
OIDCDiscoveryEndpoint: config.OIDCDiscoveryEndpoint,
UserInfoMap: config.UserInfoMap,
ValidateIdTokenPayload: config.ValidateIdTokenPayload,
ValidateAccessToken: config.ValidateAccessToken,
RequireEmail: config.RequireEmail,
GenerateFakeEmail: config.GenerateFakeEmail,
}
Expand Down
37 changes: 37 additions & 0 deletions recipe/thirdparty/providers/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package providers

import (
"encoding/base64"
"errors"
"fmt"

Expand All @@ -36,6 +37,42 @@ func Github(input tpmodels.ProviderInput) *tpmodels.TypeProvider {
input.Config.TokenEndpoint = "https://github.com/login/oauth/access_token"
}

if input.Config.ValidateAccessToken == nil {
input.Config.ValidateAccessToken = func(accessToken string, clientConfig tpmodels.ProviderConfigForClientType, userContext supertokens.UserContext) error {
basicAuthToken := base64.StdEncoding.EncodeToString([]byte(clientConfig.ClientID + ":" + clientConfig.ClientSecret))
wrongClientIdError := errors.New("Access token does not belong to your application")

resp, err := doPostRequest("https://api.github.com/applications/"+clientConfig.ClientID+"/token", map[string]interface{}{
"access_token": accessToken,
}, map[string]interface{}{
"Authorization": "Basic " + basicAuthToken,
"Content-Type": "application/json",
})

if err != nil {
return errors.New("Invalid access token")
}

app, appOk := resp["app"]

if !appOk {
return wrongClientIdError
}

clientId, clientIdOk := app.(map[string]interface{})["client_id"]

if !clientIdOk {
return wrongClientIdError
}

if clientId != clientConfig.ClientID {
return wrongClientIdError
}

return nil
}
}

oOverride := input.Override

input.Override = func(originalImplementation *tpmodels.TypeProvider) *tpmodels.TypeProvider {
Expand Down
7 changes: 7 additions & 0 deletions recipe/thirdparty/providers/oauth2_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,13 @@ func oauth2_GetUserInfo(config tpmodels.ProviderConfigForClientType, oAuthTokens
}
}

if config.ValidateAccessToken != nil && accessTokenOk {
err := config.ValidateAccessToken(accessToken, config, userContext)
if err != nil {
return tpmodels.TypeUserInfo{}, err
}
}

if accessTokenOk && config.UserInfoEndpoint != "" {
headers := map[string]string{
"Authorization": "Bearer " + accessToken,
Expand Down
2 changes: 2 additions & 0 deletions recipe/thirdparty/tpmodels/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ type ProviderConfig struct {
RequireEmail *bool `json:"requireEmail,omitempty"`

ValidateIdTokenPayload func(idTokenPayload map[string]interface{}, clientConfig ProviderConfigForClientType, userContext supertokens.UserContext) error `json:"-"`
ValidateAccessToken func(accessToken string, clientConfig ProviderConfigForClientType, userContext supertokens.UserContext) error `json:"-"`
GenerateFakeEmail func(thirdPartyUserId string, tenantId string, userContext supertokens.UserContext) string `json:"-"`
}

Expand Down Expand Up @@ -158,6 +159,7 @@ type ProviderConfigForClientType struct {
OIDCDiscoveryEndpoint string
UserInfoMap TypeUserInfoMap
ValidateIdTokenPayload func(idTokenPayload map[string]interface{}, clientConfig ProviderConfigForClientType, userContext supertokens.UserContext) error
ValidateAccessToken func(accessToken string, clientConfig ProviderConfigForClientType, userContext supertokens.UserContext) error

RequireEmail *bool
GenerateFakeEmail func(thirdPartyUserId string, tenantId string, userContext supertokens.UserContext) string
Expand Down
2 changes: 1 addition & 1 deletion supertokens/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ const (
)

// VERSION current version of the lib
const VERSION = "0.16.0"
const VERSION = "0.16.1"

var (
cdiSupported = []string{"3.0"}
Expand Down
Loading