diff --git a/CHANGELOG.md b/CHANGELOG.md index 5941e1ee..2a8bcd92 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [unreleased] + +## [0.17.1] - 2023-11-24 + +### Added + +- Adds support for configuring multiple frontend domains to be used with the same backend +- Added new `Origin` and `GetOrigin` properties to `AppInfo`, this can be configured to allow you to conditionally return the value of the frontend domain. This property will replace `WebsiteDomain` in a future release of `supertokens-golang` +- `WebsiteDomain` inside `AppInfo` is now optional. Using `Origin` or `GetOrigin` is recommended over using `WebsiteDomain`. This is not a breaking change and using `WebsiteDomain` will continue to work. + ## [0.17.0] - 2023-11-14 ### Breaking change diff --git a/recipe/dashboard/api/analyticsPOST.go b/recipe/dashboard/api/analyticsPOST.go index d96ea8b0..fbb2486f 100644 --- a/recipe/dashboard/api/analyticsPOST.go +++ b/recipe/dashboard/api/analyticsPOST.go @@ -61,8 +61,13 @@ func AnalyticsPost(apiInterface dashboardmodels.APIInterface, tenantId string, o } } + websiteDomain, err := supertokensInstance.AppInfo.GetOrigin(nil, &map[string]interface{}{}) + if err != nil { + return analyticsPostResponse{}, err + } + data := map[string]interface{}{ - "websiteDomain": supertokensInstance.AppInfo.WebsiteDomain.GetAsStringDangerous(), + "websiteDomain": websiteDomain.GetAsStringDangerous(), "apiDomain": supertokensInstance.AppInfo.APIDomain.GetAsStringDangerous(), "appName": supertokensInstance.AppInfo.AppName, "sdk": "golang", diff --git a/recipe/dashboard/recipe.go b/recipe/dashboard/recipe.go index 135ba154..d5502461 100644 --- a/recipe/dashboard/recipe.go +++ b/recipe/dashboard/recipe.go @@ -350,7 +350,7 @@ func (r *Recipe) getAllCORSHeaders() []string { return []string{} } -func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter) (bool, error) { +func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter, userContext supertokens.UserContext) (bool, error) { return false, nil } diff --git a/recipe/emailpassword/api/implementation.go b/recipe/emailpassword/api/implementation.go index 6bf541ec..209fd5dd 100644 --- a/recipe/emailpassword/api/implementation.go +++ b/recipe/emailpassword/api/implementation.go @@ -66,11 +66,13 @@ func MakeAPIImplementation() epmodels.APIInterface { }, nil } - passwordResetLink := GetPasswordResetLink( + passwordResetLink, err := GetPasswordResetLink( options.AppInfo, options.RecipeID, response.OK.Token, tenantId, + options.Req, + userContext, ) if err != nil { diff --git a/recipe/emailpassword/api/utils.go b/recipe/emailpassword/api/utils.go index 74a7a7d2..fa6db009 100644 --- a/recipe/emailpassword/api/utils.go +++ b/recipe/emailpassword/api/utils.go @@ -18,6 +18,7 @@ package api import ( "encoding/json" "fmt" + "net/http" "strings" "github.com/supertokens/supertokens-golang/recipe/emailpassword/epmodels" @@ -125,13 +126,17 @@ func validateFormOrThrowError(configFormFields []epmodels.NormalisedFormField, i return nil } -func GetPasswordResetLink(appInfo supertokens.NormalisedAppinfo, recipeID string, token string, tenantId string) string { +func GetPasswordResetLink(appInfo supertokens.NormalisedAppinfo, recipeID string, token string, tenantId string, request *http.Request, userContext supertokens.UserContext) (string, error) { + websiteDomain, err := appInfo.GetOrigin(request, userContext) + if err != nil { + return "", err + } return fmt.Sprintf( "%s%s/reset-password?token=%s&rid=%s&tenantId=%s", - appInfo.WebsiteDomain.GetAsStringDangerous(), + websiteDomain.GetAsStringDangerous(), appInfo.WebsiteBasePath.GetAsStringDangerous(), token, recipeID, tenantId, - ) + ), nil } diff --git a/recipe/emailpassword/main.go b/recipe/emailpassword/main.go index f18d2c32..e157d16a 100644 --- a/recipe/emailpassword/main.go +++ b/recipe/emailpassword/main.go @@ -138,14 +138,22 @@ func CreateResetPasswordLink(tenantId string, userID string, userContext ...supe return epmodels.CreateResetPasswordLinkResponse{}, err } + link, err := api.GetPasswordResetLink( + instance.RecipeModule.GetAppInfo(), + instance.RecipeModule.GetRecipeID(), + tokenResponse.OK.Token, + tenantId, + supertokens.GetRequestFromUserContext(userContext[0]), + userContext[0], + ) + + if err != nil { + return epmodels.CreateResetPasswordLinkResponse{}, err + } + return epmodels.CreateResetPasswordLinkResponse{ OK: &struct{ Link string }{ - Link: api.GetPasswordResetLink( - instance.RecipeModule.GetAppInfo(), - instance.RecipeModule.GetRecipeID(), - tokenResponse.OK.Token, - tenantId, - ), + Link: link, }, }, nil } diff --git a/recipe/emailpassword/passwordReset_test.go b/recipe/emailpassword/passwordReset_test.go index f64454c9..45d7d39d 100644 --- a/recipe/emailpassword/passwordReset_test.go +++ b/recipe/emailpassword/passwordReset_test.go @@ -455,3 +455,106 @@ func TestValidTokenInputAndPasswordHasChanged(t *testing.T) { assert.Equal(t, userInfo["id"], result3["user"].(map[string]interface{})["id"].(string)) assert.Equal(t, userInfo["email"], result3["user"].(map[string]interface{})["email"].(string)) } + +func TestPasswordResetLinkUsesOriginFunctionIfProvided(t *testing.T) { + resetURL := "" + tokenInfo := "" + ridInfo := "" + sendEmailFunc := func(input emaildelivery.EmailType, userContext supertokens.UserContext) error { + u, err := url.Parse(input.PasswordReset.PasswordResetLink) + if err != nil { + return err + } + resetURL = u.Scheme + "://" + u.Host + u.Path + tokenInfo = u.Query().Get("token") + ridInfo = u.Query().Get("rid") + return nil + } + configValue := supertokens.TypeInput{ + Supertokens: &supertokens.ConnectionInfo{ + ConnectionURI: "http://localhost:8080", + }, + AppInfo: supertokens.AppInfo{ + APIDomain: "api.supertokens.io", + AppName: "SuperTokens", + GetOrigin: func(request *http.Request, userContext supertokens.UserContext) (string, error) { + // read request body + decoder := json.NewDecoder(request.Body) + var requestBody map[string]interface{} + err := decoder.Decode(&requestBody) + if err != nil { + return "https://supertokens.com", nil + } + if requestBody["origin"] == nil { + return "https://supertokens.com", nil + } + return requestBody["origin"].(string), nil + }, + }, + RecipeList: []supertokens.Recipe{ + Init(&epmodels.TypeInput{ + EmailDelivery: &emaildelivery.TypeInput{ + Service: &emaildelivery.EmailDeliveryInterface{ + SendEmail: &sendEmailFunc, + }, + }, + }), + session.Init(nil), + }, + } + + BeforeEach() + unittesting.StartUpST("localhost", "8080") + defer AfterEach() + err := supertokens.Init(configValue) + if err != nil { + t.Error(err.Error()) + } + mux := http.NewServeMux() + testServer := httptest.NewServer(supertokens.Middleware(mux)) + defer testServer.Close() + + res, err := unittesting.SignupRequest("random@gmail.com", "validpass123", testServer.URL) + if err != nil { + t.Error(err.Error()) + } + assert.NoError(t, err) + dataInBytes, err := io.ReadAll(res.Body) + if err != nil { + t.Error(err.Error()) + } + res.Body.Close() + var result map[string]interface{} + err = json.Unmarshal(dataInBytes, &result) + if err != nil { + t.Error(err.Error()) + } + assert.Equal(t, 200, res.StatusCode) + assert.Equal(t, "OK", result["status"]) + + formFields := map[string]interface{}{ + "origin": "localhost:2000", + "formFields": []map[string]interface{}{{ + "id": "email", + "value": "random@gmail.com", + }}, + } + + postBody, err := json.Marshal(formFields) + if err != nil { + t.Error(err.Error()) + } + + resp, err := http.Post(testServer.URL+"/auth/user/password/reset/token", "application/json", bytes.NewBuffer(postBody)) + + if err != nil { + t.Error(err.Error()) + } + + assert.NoError(t, err) + + assert.Equal(t, 200, resp.StatusCode) + assert.Equal(t, "http://localhost:2000/auth/reset-password", resetURL) + assert.NotEmpty(t, tokenInfo) + assert.True(t, strings.HasPrefix(ridInfo, "emailpassword")) +} diff --git a/recipe/emailpassword/recipe.go b/recipe/emailpassword/recipe.go index f37f9d19..0d501ccc 100644 --- a/recipe/emailpassword/recipe.go +++ b/recipe/emailpassword/recipe.go @@ -181,7 +181,7 @@ func (r *Recipe) getAllCORSHeaders() []string { return []string{} } -func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter) (bool, error) { +func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter, userContext supertokens.UserContext) (bool, error) { if defaultErrors.As(err, &errors.FieldError{}) { errs := err.(errors.FieldError) return true, supertokens.Send200Response(res, map[string]interface{}{ diff --git a/recipe/emailverification/api/implementation.go b/recipe/emailverification/api/implementation.go index 21cb6888..5c144b13 100644 --- a/recipe/emailverification/api/implementation.go +++ b/recipe/emailverification/api/implementation.go @@ -122,13 +122,20 @@ func MakeAPIImplementation() evmodels.APIInterface { ID: userID, Email: email.OK.Email, } - emailVerificationURL := GetEmailVerifyLink( + + emailVerificationURL, err := GetEmailVerifyLink( options.AppInfo, response.OK.Token, options.RecipeID, sessionContainer.GetTenantIdWithContext(userContext), + options.Req, + userContext, ) + if err != nil { + return evmodels.GenerateEmailVerifyTokenPOSTResponse{}, err + } + supertokens.LogDebugMessage(fmt.Sprintf("Sending email verification email to %s", email.OK.Email)) err = (*options.EmailDelivery.IngredientInterfaceImpl.SendEmail)(emaildelivery.EmailType{ EmailVerification: &emaildelivery.EmailVerificationType{ diff --git a/recipe/emailverification/api/utils.go b/recipe/emailverification/api/utils.go index f759c515..b18490c4 100644 --- a/recipe/emailverification/api/utils.go +++ b/recipe/emailverification/api/utils.go @@ -2,17 +2,22 @@ package api import ( "fmt" + "net/http" "github.com/supertokens/supertokens-golang/supertokens" ) -func GetEmailVerifyLink(appInfo supertokens.NormalisedAppinfo, token string, recipeID string, tenantId string) string { +func GetEmailVerifyLink(appInfo supertokens.NormalisedAppinfo, token string, recipeID string, tenantId string, request *http.Request, userContext supertokens.UserContext) (string, error) { + websiteDomain, err := appInfo.GetOrigin(request, userContext) + if err != nil { + return "", err + } return fmt.Sprintf( "%s%s/verify-email?token=%s&rid=%s&tenantId=%s", - appInfo.WebsiteDomain.GetAsStringDangerous(), + websiteDomain.GetAsStringDangerous(), appInfo.WebsiteBasePath.GetAsStringDangerous(), token, recipeID, tenantId, - ) + ), nil } diff --git a/recipe/emailverification/emailverification_email_test.go b/recipe/emailverification/emailverification_email_test.go index 4f3b1004..ae210578 100644 --- a/recipe/emailverification/emailverification_email_test.go +++ b/recipe/emailverification/emailverification_email_test.go @@ -26,6 +26,7 @@ import ( "github.com/supertokens/supertokens-golang/recipe/session" "github.com/supertokens/supertokens-golang/recipe/session/sessmodels" "github.com/supertokens/supertokens-golang/supertokens" + "github.com/supertokens/supertokens-golang/test/unittesting" ) func TestBackwardCompatibilityServiceWithoutCustomFunction(t *testing.T) { @@ -286,6 +287,60 @@ func TestSMTPServiceOverrideDefaultEmailTemplate(t *testing.T) { assert.Equal(t, sendRawEmailCalled, true) } +func TestThatLinkUsesResultFromOriginFunction(t *testing.T) { + link := "" + configValue := supertokens.TypeInput{ + Supertokens: &supertokens.ConnectionInfo{ + ConnectionURI: "http://localhost:8080", + }, + AppInfo: supertokens.AppInfo{ + APIDomain: "api.supertokens.io", + AppName: "SuperTokens", + GetOrigin: func(request *http.Request, userContext supertokens.UserContext) (string, error) { + return (*userContext)["link"].(string), nil + }, + }, + RecipeList: []supertokens.Recipe{ + Init(evmodels.TypeInput{ + Mode: "OPTIONAL", + EmailDelivery: &emaildelivery.TypeInput{ + Override: func(originalImplementation emaildelivery.EmailDeliveryInterface) emaildelivery.EmailDeliveryInterface { + (*originalImplementation.SendEmail) = func(input emaildelivery.EmailType, userContext supertokens.UserContext) error { + link = input.EmailVerification.EmailVerifyLink + return nil + } + return originalImplementation + }, + }, + }), + session.Init(nil), + }, + } + + BeforeEach() + unittesting.StartUpST("localhost", "8080") + defer AfterEach() + err := supertokens.Init(configValue) + if err != nil { + t.Error(err.Error()) + } + + email := "test@exmaple.com" + resp, err := SendEmailVerificationEmail("public", "userId", &email, &map[string]interface{}{ + "link": "localhost:8080", + }) + if err != nil { + t.Error(err.Error()) + } + assert.True(t, resp.OK != nil) + + assert.Equal(t, EmailVerificationEmailSentForTest, false) + // assert that link starts with http://localhost:8080. We use starts with because the link + // can continue a path and random query params too + assert.Equal(t, link[:21], "http://localhost:8080") + +} + // func TestSMTPServiceManually(t *testing.T) { // targetEmail := "..." // fromEmail := "no-reply@supertokens.com" diff --git a/recipe/emailverification/main.go b/recipe/emailverification/main.go index 541d77b6..61b68c0a 100644 --- a/recipe/emailverification/main.go +++ b/recipe/emailverification/main.go @@ -176,9 +176,15 @@ func CreateEmailVerificationLink(tenantId string, userID string, email *string, }, nil } + link, err := api.GetEmailVerifyLink(st.AppInfo, emailVerificationTokenResponse.OK.Token, instance.RecipeModule.GetRecipeID(), tenantId, supertokens.GetRequestFromUserContext(userContext[0]), userContext[0]) + + if err != nil { + return evmodels.CreateEmailVerificationLinkResponse{}, err + } + return evmodels.CreateEmailVerificationLinkResponse{ OK: &struct{ Link string }{ - Link: api.GetEmailVerifyLink(st.AppInfo, emailVerificationTokenResponse.OK.Token, instance.RecipeModule.GetRecipeID(), tenantId), + Link: link, }, }, nil } diff --git a/recipe/emailverification/recipe.go b/recipe/emailverification/recipe.go index e9889302..f1c02bb6 100644 --- a/recipe/emailverification/recipe.go +++ b/recipe/emailverification/recipe.go @@ -199,7 +199,7 @@ func (r *Recipe) getAllCORSHeaders() []string { return []string{} } -func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter) (bool, error) { +func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter, userContext supertokens.UserContext) (bool, error) { return false, nil } diff --git a/recipe/jwt/recipe.go b/recipe/jwt/recipe.go index af7a6053..a1668da0 100644 --- a/recipe/jwt/recipe.go +++ b/recipe/jwt/recipe.go @@ -107,7 +107,7 @@ func (r *Recipe) getAllCORSHeaders() []string { return []string{} } -func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter) (bool, error) { +func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter, userContext supertokens.UserContext) (bool, error) { return false, nil } diff --git a/recipe/multitenancy/recipe.go b/recipe/multitenancy/recipe.go index 74fbd89d..7c514829 100644 --- a/recipe/multitenancy/recipe.go +++ b/recipe/multitenancy/recipe.go @@ -144,7 +144,7 @@ func (r *Recipe) getAllCORSHeaders() []string { return []string{} } -func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter) (bool, error) { +func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter, userContext supertokens.UserContext) (bool, error) { return false, nil } diff --git a/recipe/openid/recipe.go b/recipe/openid/recipe.go index 1594b305..c87632cd 100644 --- a/recipe/openid/recipe.go +++ b/recipe/openid/recipe.go @@ -127,8 +127,8 @@ func (r *Recipe) getAllCORSHeaders() []string { return r.JwtRecipe.RecipeModule.GetAllCORSHeaders() } -func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter) (bool, error) { - return r.JwtRecipe.RecipeModule.HandleError(err, req, res) +func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter, userContext supertokens.UserContext) (bool, error) { + return r.JwtRecipe.RecipeModule.HandleError(err, req, res, userContext) } func ResetForTest() { diff --git a/recipe/passwordless/api/implementation.go b/recipe/passwordless/api/implementation.go index fb18c9e2..8cf22001 100644 --- a/recipe/passwordless/api/implementation.go +++ b/recipe/passwordless/api/implementation.go @@ -99,13 +99,18 @@ func MakeAPIImplementation() plessmodels.APIInterface { var userInputCode *string flowType := options.Config.FlowType if flowType == "MAGIC_LINK" || flowType == "USER_INPUT_CODE_AND_MAGIC_LINK" { - link := GetMagicLink( + link, err := GetMagicLink( options.AppInfo, options.RecipeID, response.OK.PreAuthSessionID, response.OK.LinkCode, tenantId, + options.Req, + userContext, ) + if err != nil { + return plessmodels.CreateCodePOSTResponse{}, err + } magicLink = &link } @@ -276,13 +281,18 @@ func MakeAPIImplementation() plessmodels.APIInterface { var userInputCode *string flowType := options.Config.FlowType if flowType == "MAGIC_LINK" || flowType == "USER_INPUT_CODE_AND_MAGIC_LINK" { - link := GetMagicLink( + link, err := GetMagicLink( options.AppInfo, options.RecipeID, response.OK.PreAuthSessionID, response.OK.LinkCode, tenantId, + options.Req, + userContext, ) + if err != nil { + return plessmodels.ResendCodePOSTResponse{}, err + } magicLink = &link } diff --git a/recipe/passwordless/api/utils.go b/recipe/passwordless/api/utils.go index a9068222..95d002dd 100644 --- a/recipe/passwordless/api/utils.go +++ b/recipe/passwordless/api/utils.go @@ -2,18 +2,23 @@ package api import ( "fmt" + "net/http" "github.com/supertokens/supertokens-golang/supertokens" ) -func GetMagicLink(appInfo supertokens.NormalisedAppinfo, recipeID string, preAuthSessionID string, linkCode string, tenantId string) string { +func GetMagicLink(appInfo supertokens.NormalisedAppinfo, recipeID string, preAuthSessionID string, linkCode string, tenantId string, request *http.Request, userContext supertokens.UserContext) (string, error) { + websiteDomain, err := appInfo.GetOrigin(request, userContext) + if err != nil { + return "", err + } return fmt.Sprintf( "%s%s/verify?rid=%s&preAuthSessionId=%s&tenantId=%s#%s", - appInfo.WebsiteDomain.GetAsStringDangerous(), + websiteDomain.GetAsStringDangerous(), appInfo.WebsiteBasePath.GetAsStringDangerous(), recipeID, preAuthSessionID, tenantId, linkCode, - ) + ), nil } diff --git a/recipe/passwordless/passwordless_email_test.go b/recipe/passwordless/passwordless_email_test.go index dc2666fc..9ad85df0 100644 --- a/recipe/passwordless/passwordless_email_test.go +++ b/recipe/passwordless/passwordless_email_test.go @@ -17,9 +17,11 @@ package passwordless import ( + "bytes" "encoding/json" "io/ioutil" "net/http" + "net/http/httptest" "testing" "github.com/stretchr/testify/assert" @@ -629,3 +631,116 @@ func TestSMTPServiceOverrideEmailTemplateForMagicLinkAndOtp(t *testing.T) { assert.Equal(t, customCalled, false) assert.Equal(t, sendRawEmailCalled, true) } + +func TestThatMagicLinkUsesRightValueFromOriginFunction(t *testing.T) { + BeforeEach() + unittesting.StartUpST("localhost", "8080") + defer AfterEach() + + customCalled := false + plessEmail := "" + var code, urlWithCode *string + var codeLife uint64 + + sendEmail := func(input emaildelivery.EmailType, userContext supertokens.UserContext) error { + plessEmail = input.PasswordlessLogin.Email + code = input.PasswordlessLogin.UserInputCode + urlWithCode = input.PasswordlessLogin.UrlWithLinkCode + codeLife = input.PasswordlessLogin.CodeLifetime + customCalled = true + return nil + } + + tplConfig := plessmodels.TypeInput{ + FlowType: "USER_INPUT_CODE_AND_MAGIC_LINK", + EmailDelivery: &emaildelivery.TypeInput{ + Service: &emaildelivery.EmailDeliveryInterface{ + SendEmail: &sendEmail, + }, + }, + ContactMethodEmail: plessmodels.ContactMethodEmailConfig{ + Enabled: true, + }, + } + + config := supertokens.TypeInput{ + Supertokens: &supertokens.ConnectionInfo{ + ConnectionURI: "http://localhost:8080", + }, + AppInfo: supertokens.AppInfo{ + APIDomain: "api.supertokens.io", + AppName: "SuperTokens", + GetOrigin: func(request *http.Request, userContext supertokens.UserContext) (string, error) { + // read request body + decoder := json.NewDecoder(request.Body) + var requestBody map[string]interface{} + err := decoder.Decode(&requestBody) + if err != nil { + return "https://supertokens.com", nil + } + if requestBody["origin"] == nil { + return "https://supertokens.com", nil + } + return requestBody["origin"].(string), nil + }, + }, + RecipeList: []supertokens.Recipe{ + session.Init(nil), + Init(tplConfig), + }, + } + + err := supertokens.Init(config) + assert.NoError(t, err) + + mux := http.NewServeMux() + testServer := httptest.NewServer(supertokens.Middleware(mux)) + defer testServer.Close() + + querier, err := supertokens.GetNewQuerierInstanceOrThrowError("") + if err != nil { + t.Error(err.Error()) + } + cdiVersion, err := querier.GetQuerierAPIVersion() + if err != nil { + t.Error(err.Error()) + } + if unittesting.MaxVersion("2.10", cdiVersion) == "2.10" { + return + } + + body := map[string]string{ + "email": "test@example.com", + "origin": "localhost:2000", + } + + postBody, err := json.Marshal(body) + if err != nil { + t.Error(err.Error()) + return + } + + resp, err := http.Post(testServer.URL+"/auth/signinup/code", "application/json", bytes.NewBuffer(postBody)) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + bodyBytes, err := ioutil.ReadAll(resp.Body) + assert.NoError(t, err) + body = map[string]string{} + + err = json.Unmarshal(bodyBytes, &body) + assert.NoError(t, err) + + // Default handler not called + assert.False(t, PasswordlessLoginEmailSentForTest) + assert.Empty(t, PasswordlessLoginEmailDataForTest.Email) + assert.Nil(t, PasswordlessLoginEmailDataForTest.UserInputCode) + assert.Nil(t, PasswordlessLoginEmailDataForTest.UrlWithLinkCode) + + // Custom handler called + assert.Equal(t, plessEmail, "test@example.com") + assert.NotNil(t, code) + assert.Equal(t, (*urlWithCode)[:21], "http://localhost:2000") + assert.NotZero(t, codeLife) + assert.True(t, customCalled) +} diff --git a/recipe/passwordless/recipe.go b/recipe/passwordless/recipe.go index 23af84ed..244a80d3 100644 --- a/recipe/passwordless/recipe.go +++ b/recipe/passwordless/recipe.go @@ -188,7 +188,7 @@ func (r *Recipe) getAllCORSHeaders() []string { return []string{} } -func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter) (bool, error) { +func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter, userContext supertokens.UserContext) (bool, error) { return false, nil } @@ -210,15 +210,17 @@ func (r *Recipe) CreateMagicLink(email *string, phoneNumber *string, tenantId st if err != nil { return "", err } - link := api.GetMagicLink( + link, err := api.GetMagicLink( stInstance.AppInfo, r.RecipeModule.GetRecipeID(), response.OK.PreAuthSessionID, response.OK.LinkCode, tenantId, + supertokens.GetRequestFromUserContext(userContext), + userContext, ) - return link, nil + return link, err } func (r *Recipe) SignInUp(email *string, phoneNumber *string, tenantId string, userContext supertokens.UserContext) (struct { diff --git a/recipe/session/config_test.go b/recipe/session/config_test.go index ab68b6e1..c2d55b61 100644 --- a/recipe/session/config_test.go +++ b/recipe/session/config_test.go @@ -144,7 +144,7 @@ func TestSuperTokensInitWithoutWebsiteDomain(t *testing.T) { defer AfterEach() err := supertokens.Init(configValue) if err != nil { - assert.Equal(t, err.Error(), "Please provide your websiteDomain inside the appInfo object when calling supertokens.init") + assert.Equal(t, err.Error(), "Please provide either Origin, GetOrigin or WebsiteDomain inside the appInfo object when calling supertokens.init") } else { t.Fail() } @@ -437,8 +437,17 @@ func TestSuperTokensInitWithNoneLaxFalseSessionConfigResults(t *testing.T) { if err != nil { t.Error(err.Error()) } - assert.Equal(t, sessionSingletonInstance.Config.AntiCsrf, "NONE") - assert.Equal(t, sessionSingletonInstance.Config.CookieSameSite, "lax") + antiCsrf, err := sessionSingletonInstance.Config.AntiCsrfFunctionOrString.FunctionValue(nil, nil) + if err != nil { + t.Error(err.Error()) + } + assert.Equal(t, antiCsrf, "NONE") + assert.True(t, sessionSingletonInstance.Config.AntiCsrfFunctionOrString.StrValue == "") + cookieSameSite, err := sessionSingletonInstance.Config.GetCookieSameSite(nil, nil) + if err != nil { + t.Error(err.Error()) + } + assert.Equal(t, cookieSameSite, "lax") assert.Equal(t, sessionSingletonInstance.Config.CookieSecure, false) } @@ -475,8 +484,12 @@ func TestSuperTokensInitWithCustomHeaderLaxTrueSessionConfigResults(t *testing.T if err != nil { t.Error(err.Error()) } - assert.Equal(t, sessionSingletonInstance.Config.AntiCsrf, "VIA_CUSTOM_HEADER") - assert.Equal(t, sessionSingletonInstance.Config.CookieSameSite, "lax") + assert.Equal(t, sessionSingletonInstance.Config.AntiCsrfFunctionOrString.StrValue, "VIA_CUSTOM_HEADER") + cookieSameSite, err := sessionSingletonInstance.Config.GetCookieSameSite(nil, nil) + if err != nil { + t.Error(err.Error()) + } + assert.Equal(t, cookieSameSite, "lax") assert.Equal(t, sessionSingletonInstance.Config.CookieSecure, true) } @@ -514,8 +527,12 @@ func TestSuperTokensInitWithCustomHeaderLaxFalseSessionConfigResults(t *testing. if err != nil { t.Error(err.Error()) } - assert.Equal(t, sessionSingletonInstance.Config.AntiCsrf, "VIA_CUSTOM_HEADER") - assert.Equal(t, sessionSingletonInstance.Config.CookieSameSite, "lax") + assert.Equal(t, sessionSingletonInstance.Config.AntiCsrfFunctionOrString.StrValue, "VIA_CUSTOM_HEADER") + cookieSameSite, err := sessionSingletonInstance.Config.GetCookieSameSite(nil, nil) + if err != nil { + t.Error(err.Error()) + } + assert.Equal(t, cookieSameSite, "lax") assert.Equal(t, sessionSingletonInstance.Config.CookieSecure, false) } @@ -548,8 +565,16 @@ func TestSuperTokensInitWithCustomHeaderNoneTrueSessionConfigResultsWithNormalWe if err != nil { t.Error(err.Error()) } - assert.Equal(t, sessionSingletonInstance.Config.AntiCsrf, "VIA_CUSTOM_HEADER") - assert.Equal(t, sessionSingletonInstance.Config.CookieSameSite, "none") + anticsrf, err := sessionSingletonInstance.Config.AntiCsrfFunctionOrString.FunctionValue(nil, nil) + if err != nil { + t.Error(err.Error()) + } + assert.Equal(t, anticsrf, "VIA_CUSTOM_HEADER") + cookieSameSite, err := sessionSingletonInstance.Config.GetCookieSameSite(nil, nil) + if err != nil { + t.Error(err.Error()) + } + assert.Equal(t, cookieSameSite, "none") assert.Equal(t, sessionSingletonInstance.Config.CookieSecure, true) } @@ -582,8 +607,16 @@ func TestSuperTokensInitWithCustomHeaderNoneTrueSessionConfigResultsWithLocalWeb if err != nil { t.Error(err.Error()) } - assert.Equal(t, sessionSingletonInstance.Config.AntiCsrf, "VIA_CUSTOM_HEADER") - assert.Equal(t, sessionSingletonInstance.Config.CookieSameSite, "none") + anticsrf, err := sessionSingletonInstance.Config.AntiCsrfFunctionOrString.FunctionValue(nil, nil) + if err != nil { + t.Error(err.Error()) + } + assert.Equal(t, anticsrf, "VIA_CUSTOM_HEADER") + cookieSameSite, err := sessionSingletonInstance.Config.GetCookieSameSite(nil, nil) + if err != nil { + t.Error(err.Error()) + } + assert.Equal(t, cookieSameSite, "none") assert.Equal(t, sessionSingletonInstance.Config.CookieSecure, true) } @@ -619,11 +652,11 @@ func TestSuperTokensWithAntiCSRFNone(t *testing.T) { if err != nil { t.Error(err.Error()) } - singletoneSessionRecipeInstance, err := getRecipeInstanceOrThrowError() + singletonSessionRecipeInstance, err := getRecipeInstanceOrThrowError() if err != nil { t.Error(err.Error()) } - assert.Equal(t, singletoneSessionRecipeInstance.Config.AntiCsrf, "NONE") + assert.Equal(t, singletonSessionRecipeInstance.Config.AntiCsrfFunctionOrString.StrValue, "NONE") } func TestSuperTokensWithAntiCSRFRandom(t *testing.T) { @@ -737,12 +770,16 @@ func TestSuperTokensForTheDefaultCookieValues(t *testing.T) { if err != nil { t.Error(err.Error()) } - singletoneSessionRecipeInstance, err := getRecipeInstanceOrThrowError() + singletonSessionRecipeInstance, err := getRecipeInstanceOrThrowError() + if err != nil { + t.Error(err.Error()) + } + assert.Equal(t, singletonSessionRecipeInstance.Config.CookieSecure, true) + cookieSameSite, err := singletonSessionRecipeInstance.Config.GetCookieSameSite(nil, nil) if err != nil { t.Error(err.Error()) } - assert.Equal(t, singletoneSessionRecipeInstance.Config.CookieSecure, true) - assert.Equal(t, singletoneSessionRecipeInstance.Config.CookieSameSite, "none") + assert.Equal(t, cookieSameSite, "none") } func TestSuperTokensInitWithWrongConfigSchema(t *testing.T) { @@ -867,15 +904,19 @@ func TestSuperTokensDefaultCookieConfig(t *testing.T) { if err != nil { t.Error(err.Error()) } - singletoneSessionRecipeInstance, err := getRecipeInstanceOrThrowError() + singletonSessionRecipeInstance, err := getRecipeInstanceOrThrowError() if err != nil { t.Error(err.Error()) } - assert.Nil(t, singletoneSessionRecipeInstance.Config.CookieDomain) - assert.Equal(t, singletoneSessionRecipeInstance.Config.CookieSameSite, "lax") - assert.Equal(t, singletoneSessionRecipeInstance.Config.CookieSecure, true) - assert.Equal(t, singletoneSessionRecipeInstance.Config.RefreshTokenPath.GetAsStringDangerous(), "/auth/session/refresh") - assert.Equal(t, singletoneSessionRecipeInstance.Config.SessionExpiredStatusCode, 401) + assert.Nil(t, singletonSessionRecipeInstance.Config.CookieDomain) + cookieSameSite, err := singletonSessionRecipeInstance.Config.GetCookieSameSite(nil, nil) + if err != nil { + t.Error(err.Error()) + } + assert.Equal(t, cookieSameSite, "lax") + assert.Equal(t, singletonSessionRecipeInstance.Config.CookieSecure, true) + assert.Equal(t, singletonSessionRecipeInstance.Config.RefreshTokenPath.GetAsStringDangerous(), "/auth/session/refresh") + assert.Equal(t, singletonSessionRecipeInstance.Config.SessionExpiredStatusCode, 401) } func TestSuperTokensInitWithAPIGateWayPath(t *testing.T) { @@ -1256,7 +1297,11 @@ func TestCookieSameSiteWithEC2PublicURL(t *testing.T) { } assert.True(t, recipe.Config.CookieDomain == nil) - assert.Equal(t, recipe.Config.CookieSameSite, "none") + cookieSameSiteValue, err := recipe.Config.GetCookieSameSite(nil, nil) + if err != nil { + t.Error(err.Error()) + } + assert.Equal(t, cookieSameSiteValue, "none") assert.True(t, recipe.Config.CookieSecure) resetAll() @@ -1293,6 +1338,177 @@ func TestCookieSameSiteWithEC2PublicURL(t *testing.T) { } assert.True(t, recipe.Config.CookieDomain == nil) - assert.Equal(t, recipe.Config.CookieSameSite, "lax") + cookieSameSiteValue, err = recipe.Config.GetCookieSameSite(nil, nil) + if err != nil { + t.Error(err.Error()) + } + assert.Equal(t, cookieSameSiteValue, "lax") assert.False(t, recipe.Config.CookieSecure) } + +func TestInitWorksFineIfOriginIsPresent(t *testing.T) { + configValue := supertokens.TypeInput{ + Supertokens: &supertokens.ConnectionInfo{ + ConnectionURI: "http://localhost:8080", + }, + AppInfo: supertokens.AppInfo{ + AppName: "SuperTokens", + APIDomain: "api.supertokens.io", + Origin: "supertokens.io", + }, + RecipeList: []supertokens.Recipe{ + Init(nil), + }, + } + BeforeEach() + unittesting.StartUpST("localhost", "8080") + defer AfterEach() + err := supertokens.Init(configValue) + if err != nil { + t.Error(err.Error()) + } + singletonInstance, err := supertokens.GetInstanceOrThrowError() + if err != nil { + t.Error(err.Error()) + } + origin, err := singletonInstance.AppInfo.GetOrigin(nil, nil) + if err != nil { + t.Error(err.Error()) + } + assert.Equal(t, "https://supertokens.io", origin.GetAsStringDangerous()) +} + +func TestWebsiteDomainWorks(t *testing.T) { + configValue := supertokens.TypeInput{ + Supertokens: &supertokens.ConnectionInfo{ + ConnectionURI: "http://localhost:8080", + }, + AppInfo: supertokens.AppInfo{ + AppName: "SuperTokens", + APIDomain: "api.supertokens.io", + WebsiteDomain: "supertokens.io", + }, + RecipeList: []supertokens.Recipe{ + Init(nil), + }, + } + BeforeEach() + unittesting.StartUpST("localhost", "8080") + defer AfterEach() + err := supertokens.Init(configValue) + if err != nil { + t.Error(err.Error()) + } + singletonInstance, err := supertokens.GetInstanceOrThrowError() + if err != nil { + t.Error(err.Error()) + } + origin, err := singletonInstance.AppInfo.GetOrigin(nil, nil) + if err != nil { + t.Error(err.Error()) + } + assert.Equal(t, "https://supertokens.io", origin.GetAsStringDangerous()) +} + +func TestOriginFunctionWorks(t *testing.T) { + configValue := supertokens.TypeInput{ + Supertokens: &supertokens.ConnectionInfo{ + ConnectionURI: "http://localhost:8080", + }, + AppInfo: supertokens.AppInfo{ + AppName: "SuperTokens", + APIDomain: "api.supertokens.io", + GetOrigin: func(request *http.Request, userContext supertokens.UserContext) (string, error) { + return "https://test.io", nil + }, + }, + RecipeList: []supertokens.Recipe{ + Init(nil), + }, + } + BeforeEach() + unittesting.StartUpST("localhost", "8080") + defer AfterEach() + err := supertokens.Init(configValue) + if err != nil { + t.Error(err.Error()) + } + singletonInstance, err := supertokens.GetInstanceOrThrowError() + if err != nil { + t.Error(err.Error()) + } + origin, err := singletonInstance.AppInfo.GetOrigin(nil, nil) + if err != nil { + t.Error(err.Error()) + } + assert.Equal(t, "https://test.io", origin.GetAsStringDangerous()) +} + +func TestOriginIsUsedOverWebsiteDomain(t *testing.T) { + configValue := supertokens.TypeInput{ + Supertokens: &supertokens.ConnectionInfo{ + ConnectionURI: "http://localhost:8080", + }, + AppInfo: supertokens.AppInfo{ + AppName: "SuperTokens", + APIDomain: "api.supertokens.io", + Origin: "supertokens.io", + WebsiteDomain: "shouldnotbeused.com", + }, + RecipeList: []supertokens.Recipe{ + Init(nil), + }, + } + BeforeEach() + unittesting.StartUpST("localhost", "8080") + defer AfterEach() + err := supertokens.Init(configValue) + if err != nil { + t.Error(err.Error()) + } + singletonInstance, err := supertokens.GetInstanceOrThrowError() + if err != nil { + t.Error(err.Error()) + } + origin, err := singletonInstance.AppInfo.GetOrigin(nil, nil) + if err != nil { + t.Error(err.Error()) + } + assert.Equal(t, "https://supertokens.io", origin.GetAsStringDangerous()) +} + +func TestOriginFunctionIsUsedOverOrigin(t *testing.T) { + configValue := supertokens.TypeInput{ + Supertokens: &supertokens.ConnectionInfo{ + ConnectionURI: "http://localhost:8080", + }, + AppInfo: supertokens.AppInfo{ + AppName: "SuperTokens", + APIDomain: "api.supertokens.io", + Origin: "supertokens.io", + WebsiteDomain: "shouldnotbeused.com", + GetOrigin: func(request *http.Request, userContext supertokens.UserContext) (string, error) { + return "test.io", nil + }, + }, + RecipeList: []supertokens.Recipe{ + Init(nil), + }, + } + BeforeEach() + unittesting.StartUpST("localhost", "8080") + defer AfterEach() + err := supertokens.Init(configValue) + if err != nil { + t.Error(err.Error()) + } + singletonInstance, err := supertokens.GetInstanceOrThrowError() + if err != nil { + t.Error(err.Error()) + } + origin, err := singletonInstance.AppInfo.GetOrigin(nil, nil) + if err != nil { + t.Error(err.Error()) + } + assert.Equal(t, "https://test.io", origin.GetAsStringDangerous()) +} diff --git a/recipe/session/cookieAndHeaders.go b/recipe/session/cookieAndHeaders.go index e8c9b5f5..056f282b 100644 --- a/recipe/session/cookieAndHeaders.go +++ b/recipe/session/cookieAndHeaders.go @@ -20,13 +20,14 @@ import ( "encoding/json" "errors" "fmt" - "github.com/supertokens/supertokens-golang/supertokens" "net/http" "net/textproto" "net/url" "strings" "time" + "github.com/supertokens/supertokens-golang/supertokens" + "github.com/supertokens/supertokens-golang/recipe/session/sessmodels" ) @@ -54,7 +55,7 @@ type TokenInfo struct { Up interface{} `json:"up"` } -func ClearSessionFromAllTokenTransferMethods(config sessmodels.TypeNormalisedInput, req *http.Request, res http.ResponseWriter) error { +func ClearSessionFromAllTokenTransferMethods(config sessmodels.TypeNormalisedInput, req *http.Request, res http.ResponseWriter, userContext supertokens.UserContext) error { // We are clearing the session in all transfermethods to be sure to override cookies in case they have been already added to the response. // This is done to handle the following use-case: // If the app overrides signInPOST to check the ban status of the user after the original implementation and throwing an UNAUTHORISED error @@ -62,7 +63,7 @@ func ClearSessionFromAllTokenTransferMethods(config sessmodels.TypeNormalisedInp // We can't know which to clear since we can't reliably query or remove the set-cookie header added to the response (causes issues in some frameworks, i.e.: hapi) // The safe solution in this case is to overwrite all the response cookies/headers with an empty value, which is what we are doing here for _, transferMethod := range AvailableTokenTransferMethods { - err := ClearSession(config, res, transferMethod) + err := ClearSession(config, res, transferMethod, req, userContext) if err != nil { return err } @@ -70,11 +71,11 @@ func ClearSessionFromAllTokenTransferMethods(config sessmodels.TypeNormalisedInp return nil } -func ClearSession(config sessmodels.TypeNormalisedInput, res http.ResponseWriter, transferMethod sessmodels.TokenTransferMethod) error { +func ClearSession(config sessmodels.TypeNormalisedInput, res http.ResponseWriter, transferMethod sessmodels.TokenTransferMethod, request *http.Request, userContext supertokens.UserContext) error { // If we can be specific about which transferMethod we want to clear, there is no reason to clear the other ones tokenTypes := []sessmodels.TokenType{sessmodels.AccessToken, sessmodels.RefreshToken} for _, tokenType := range tokenTypes { - err := setToken(config, res, tokenType, "", 0, transferMethod) + err := setToken(config, res, tokenType, "", 0, transferMethod, request, userContext) if err != nil { return err } @@ -159,7 +160,7 @@ func GetToken(req *http.Request, tokenType sessmodels.TokenType, transferMethod return nil, errors.New("Should never happen") } -func setToken(config sessmodels.TypeNormalisedInput, res http.ResponseWriter, tokenType sessmodels.TokenType, value string, expires uint64, transferMethod sessmodels.TokenTransferMethod) error { +func setToken(config sessmodels.TypeNormalisedInput, res http.ResponseWriter, tokenType sessmodels.TokenType, value string, expires uint64, transferMethod sessmodels.TokenTransferMethod, request *http.Request, userContext supertokens.UserContext) error { supertokens.LogDebugMessage(fmt.Sprint("setToken: Setting ", tokenType, " token as ", transferMethod)) if transferMethod == sessmodels.CookieTransferMethod { cookieName, err := getCookieNameFromTokenType(tokenType) @@ -172,7 +173,7 @@ func setToken(config sessmodels.TypeNormalisedInput, res http.ResponseWriter, to } else if tokenType == sessmodels.RefreshToken { pathType = "refreshTokenPath" } - setCookie(config, res, cookieName, value, expires, pathType) + setCookie(config, res, cookieName, value, expires, pathType, request, userContext) } else if transferMethod == sessmodels.HeaderTransferMethod { headerName, err := getResponseHeaderNameForTokenType(tokenType) if err != nil { @@ -204,13 +205,17 @@ func setHeader(res http.ResponseWriter, key, value string, allowDuplicateKey boo } } -func setCookie(config sessmodels.TypeNormalisedInput, res http.ResponseWriter, name string, value string, expires uint64, pathType string) { +func setCookie(config sessmodels.TypeNormalisedInput, res http.ResponseWriter, name string, value string, expires uint64, pathType string, request *http.Request, userContext supertokens.UserContext) error { var domain string if config.CookieDomain != nil { domain = *config.CookieDomain } secure := config.CookieSecure - sameSite := config.CookieSameSite + sameSite, err := config.GetCookieSameSite(request, userContext) + + if err != nil { + return err + } path := "" if pathType == "refreshTokenPath" { @@ -239,7 +244,7 @@ func setCookie(config sessmodels.TypeNormalisedInput, res http.ResponseWriter, n SameSite: sameSiteField, } setCookieValue(res, cookie) - + return nil } func GetAuthmodeFromHeader(req *http.Request) *sessmodels.TokenTransferMethod { diff --git a/recipe/session/middleware.go b/recipe/session/middleware.go index c6125e5d..a813ee66 100644 --- a/recipe/session/middleware.go +++ b/recipe/session/middleware.go @@ -36,7 +36,7 @@ func VerifySessionHelper(recipeInstance Recipe, options *sessmodels.VerifySessio RecipeImplementation: recipeInstance.RecipeImpl, }, userContext) if err != nil { - err = supertokens.ErrorHandler(err, r, dw) + err = supertokens.ErrorHandler(err, r, dw, userContext) if err != nil { recipeInstance.RecipeModule.OnSuperTokensAPIError(err, r, dw) } diff --git a/recipe/session/recipe.go b/recipe/session/recipe.go index 7c15a9f1..68dca1a2 100644 --- a/recipe/session/recipe.go +++ b/recipe/session/recipe.go @@ -56,13 +56,23 @@ func MakeRecipe(recipeId string, appInfo supertokens.NormalisedAppinfo, config * return Recipe{}, configError } - supertokens.LogDebugMessage("session init: AntiCsrf: " + verifiedConfig.AntiCsrf) + if verifiedConfig.AntiCsrfFunctionOrString.FunctionValue != nil { + supertokens.LogDebugMessage("session init: AntiCsrf: function") + } else { + supertokens.LogDebugMessage("session init: AntiCsrf: " + verifiedConfig.AntiCsrfFunctionOrString.StrValue) + } if verifiedConfig.CookieDomain != nil { supertokens.LogDebugMessage("session init: CookieDomain: " + *verifiedConfig.CookieDomain) } else { supertokens.LogDebugMessage("session init: CookieDomain: nil") } - supertokens.LogDebugMessage("session init: CookieSameSite: " + verifiedConfig.CookieSameSite) + // we intentionally use config here instead of verifiedConfig will always + // be a function for getting cookieSameSite. + if config == nil || config.CookieSameSite == nil { + supertokens.LogDebugMessage("session init: CookieSameSite: default function") + } else { + supertokens.LogDebugMessage("session init: CookieSameSite: " + *config.CookieSameSite) + } supertokens.LogDebugMessage("session init: CookieSecure: " + strconv.FormatBool(verifiedConfig.CookieSecure)) supertokens.LogDebugMessage("session init: RefreshTokenPath: " + verifiedConfig.RefreshTokenPath.GetAsStringDangerous()) supertokens.LogDebugMessage("session init: SessionExpiredStatusCode: " + strconv.Itoa(verifiedConfig.SessionExpiredStatusCode)) @@ -173,13 +183,13 @@ func (r *Recipe) getAllCORSHeaders() []string { return resp } -func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter) (bool, error) { +func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter, userContext supertokens.UserContext) (bool, error) { if defaultErrors.As(err, &errors.UnauthorizedError{}) { supertokens.LogDebugMessage("errorHandler: returning UNAUTHORISED") unauthErr := err.(errors.UnauthorizedError) if unauthErr.ClearTokens == nil || *unauthErr.ClearTokens { supertokens.LogDebugMessage("errorHandler: Clearing tokens because of UNAUTHORISED response") - ClearSessionFromAllTokenTransferMethods(r.Config, req, res) + ClearSessionFromAllTokenTransferMethods(r.Config, req, res, userContext) } return true, r.Config.ErrorHandlers.OnUnauthorised(err.Error(), req, res) } else if defaultErrors.As(err, &errors.TryRefreshTokenError{}) { @@ -187,7 +197,7 @@ func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWrit return true, r.Config.ErrorHandlers.OnTryRefreshToken(err.Error(), req, res) } else if defaultErrors.As(err, &errors.TokenTheftDetectedError{}) { supertokens.LogDebugMessage("errorHandler: clearing tokens because of TOKEN_THEFT_DETECTED response") - ClearSessionFromAllTokenTransferMethods(r.Config, req, res) + ClearSessionFromAllTokenTransferMethods(r.Config, req, res, userContext) errs := err.(errors.TokenTheftDetectedError) return true, r.Config.ErrorHandlers.OnTokenTheftDetected(errs.Payload.SessionHandle, errs.Payload.UserID, req, res) } else if defaultErrors.As(err, &errors.InvalidClaimError{}) { @@ -195,7 +205,7 @@ func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWrit errs := err.(errors.InvalidClaimError) return true, r.Config.ErrorHandlers.OnInvalidClaim(errs.InvalidClaims, req, res) } else { - return r.OpenIdRecipe.RecipeModule.HandleError(err, req, res) + return r.OpenIdRecipe.RecipeModule.HandleError(err, req, res, userContext) } } diff --git a/recipe/session/recipeImplementation.go b/recipe/session/recipeImplementation.go index 8f8ce68f..ca8b9b0e 100644 --- a/recipe/session/recipeImplementation.go +++ b/recipe/session/recipeImplementation.go @@ -175,7 +175,7 @@ func MakeRecipeImplementation(querier supertokens.Querier, config sessmodels.Typ // In all cases if sIdRefreshToken token exists (so it's a legacy session) we return TRY_REFRESH_TOKEN. The refresh endpoint will clear this cookie and try to upgrade the session. // Check https://supertokens.com/docs/contribute/decisions/session/0007 for further details and a table of expected behaviours getSession := func(accessTokenString *string, antiCsrfToken *string, options *sessmodels.VerifySessionOptions, userContext supertokens.UserContext) (sessmodels.SessionContainer, error) { - if options != nil && options.AntiCsrfCheck != nil && *options.AntiCsrfCheck != false && config.AntiCsrf == AntiCSRF_VIA_CUSTOM_HEADER { + if options != nil && options.AntiCsrfCheck != nil && *options.AntiCsrfCheck != false && config.AntiCsrfFunctionOrString.FunctionValue == nil && config.AntiCsrfFunctionOrString.StrValue == AntiCSRF_VIA_CUSTOM_HEADER { return nil, defaultErrors.New("Since the anti-csrf mode is VIA_CUSTOM_HEADER getSession can't check the CSRF token. Please either use VIA_TOKEN or set antiCsrfCheck to false") } @@ -288,7 +288,7 @@ func MakeRecipeImplementation(querier supertokens.Querier, config sessmodels.Typ } refreshSession := func(refreshToken string, antiCsrfToken *string, disableAntiCsrf bool, userContext supertokens.UserContext) (sessmodels.SessionContainer, error) { - if disableAntiCsrf != true && config.AntiCsrf == AntiCSRF_VIA_CUSTOM_HEADER { + if disableAntiCsrf != true && config.AntiCsrfFunctionOrString.FunctionValue == nil && config.AntiCsrfFunctionOrString.StrValue == AntiCSRF_VIA_CUSTOM_HEADER { return nil, defaultErrors.New("Since the anti-csrf mode is VIA_CUSTOM_HEADER getSession can't check the CSRF token. Please either use VIA_TOKEN or set antiCsrfCheck to false") } diff --git a/recipe/session/session.go b/recipe/session/session.go index 6d099881..28232c14 100644 --- a/recipe/session/session.go +++ b/recipe/session/session.go @@ -83,7 +83,7 @@ func newSessionContainer(config sessmodels.TypeNormalisedInput, session *Session // If we instead clear the cookies only when revokeSession // returns true, it can cause this kind of a bug: // https://github.com/supertokens/supertokens-node/issues/343 - ClearSession(config, session.requestResponseInfo.Res, session.requestResponseInfo.TokenTransferMethod) + ClearSession(config, session.requestResponseInfo.Res, session.requestResponseInfo.TokenTransferMethod, session.requestResponseInfo.Req, supertokens.SetRequestInUserContextIfNotDefined(userContext, session.requestResponseInfo.Req)) } return nil } @@ -195,7 +195,7 @@ func newSessionContainer(config sessmodels.TypeNormalisedInput, session *Session session.accessTokenUpdated = true if session.requestResponseInfo != nil { - setTokenErr := SetAccessTokenInResponse(config, session.requestResponseInfo.Res, session.accessToken, session.frontToken, session.requestResponseInfo.TokenTransferMethod) + setTokenErr := SetAccessTokenInResponse(config, session.requestResponseInfo.Res, session.accessToken, session.frontToken, session.requestResponseInfo.TokenTransferMethod, session.requestResponseInfo.Req, supertokens.SetRequestInUserContextIfNotDefined(userContext, session.requestResponseInfo.Req)) if setTokenErr != nil { return setTokenErr } @@ -279,6 +279,32 @@ func newSessionContainer(config sessmodels.TypeNormalisedInput, session *Session return sessionContainer.MergeIntoAccessTokenPayloadWithContext(update, userContext) } + sessionContainer.AttachToRequestResponseWithContext = func(info sessmodels.RequestResponseInfo, userContext supertokens.UserContext) error { + session.requestResponseInfo = &info + + if session.accessTokenUpdated { + err := SetAccessTokenInResponse(config, info.Res, session.accessToken, session.frontToken, info.TokenTransferMethod, session.requestResponseInfo.Req, supertokens.SetRequestInUserContextIfNotDefined(userContext, session.requestResponseInfo.Req)) + + if err != nil { + return err + } + + if session.refreshToken != nil { + err = setToken(config, info.Res, sessmodels.RefreshToken, session.refreshToken.Token, session.refreshToken.Expiry, info.TokenTransferMethod, session.requestResponseInfo.Req, supertokens.SetRequestInUserContextIfNotDefined(userContext, session.requestResponseInfo.Req)) + + if err != nil { + return err + } + } + + if session.antiCSRFToken != nil { + setAntiCsrfTokenInHeaders(info.Res, *session.antiCSRFToken) + } + } + + return nil + } + sessionContainer.RevokeSession = func() error { return sessionContainer.RevokeSessionWithContext(&map[string]interface{}{}) } @@ -347,29 +373,7 @@ func newSessionContainer(config sessmodels.TypeNormalisedInput, session *Session } sessionContainer.AttachToRequestResponse = func(info sessmodels.RequestResponseInfo) error { - session.requestResponseInfo = &info - - if session.accessTokenUpdated { - err := SetAccessTokenInResponse(config, info.Res, session.accessToken, session.frontToken, info.TokenTransferMethod) - - if err != nil { - return err - } - - if session.refreshToken != nil { - err = setToken(config, info.Res, sessmodels.RefreshToken, session.refreshToken.Token, session.refreshToken.Expiry, info.TokenTransferMethod) - - if err != nil { - return err - } - } - - if session.antiCSRFToken != nil { - setAntiCsrfTokenInHeaders(info.Res, *session.antiCSRFToken) - } - } - - return nil + return sessionContainer.AttachToRequestResponseWithContext(info, &map[string]interface{}{}) } return sessionContainer diff --git a/recipe/session/sessionFunctions.go b/recipe/session/sessionFunctions.go index 363d97aa..702c2b76 100644 --- a/recipe/session/sessionFunctions.go +++ b/recipe/session/sessionFunctions.go @@ -37,7 +37,7 @@ func createNewSessionHelper(config sessmodels.TypeNormalisedInput, querier super "userId": userID, "userDataInJWT": AccessTokenPayload, "userDataInDatabase": sessionDataInDatabase, - "enableAntiCsrf": !disableAntiCsrf && config.AntiCsrf == AntiCSRF_VIA_TOKEN, + "enableAntiCsrf": !disableAntiCsrf && config.AntiCsrfFunctionOrString.StrValue == AntiCSRF_VIA_TOKEN, "useDynamicSigningKey": config.UseDynamicAccessTokenSigningKey, } @@ -70,7 +70,7 @@ func getSessionHelper(config sessmodels.TypeNormalisedInput, querier supertokens } } - accessTokenInfo, err = GetInfoFromAccessToken(parsedAccessToken, combinedJwks, config.AntiCsrf == AntiCSRF_VIA_TOKEN && doAntiCsrfCheck) + accessTokenInfo, err = GetInfoFromAccessToken(parsedAccessToken, combinedJwks, config.AntiCsrfFunctionOrString.StrValue == AntiCSRF_VIA_TOKEN && doAntiCsrfCheck) if err != nil { if !defaultErrors.As(err, &errors.TryRefreshTokenError{}) { supertokens.LogDebugMessage("getSessionHelper: Returning TryRefreshTokenError because GetInfoFromAccessToken returned an error") @@ -121,7 +121,7 @@ func getSessionHelper(config sessmodels.TypeNormalisedInput, querier supertokens } if doAntiCsrfCheck { - if config.AntiCsrf == AntiCSRF_VIA_TOKEN { + if config.AntiCsrfFunctionOrString.StrValue == AntiCSRF_VIA_TOKEN { if accessTokenInfo != nil { if antiCsrfToken == nil || *antiCsrfToken != *accessTokenInfo.AntiCsrfToken { if antiCsrfToken == nil { @@ -133,7 +133,7 @@ func getSessionHelper(config sessmodels.TypeNormalisedInput, querier supertokens } } } - } else if config.AntiCsrf == AntiCSRF_VIA_CUSTOM_HEADER { + } else if config.AntiCsrfFunctionOrString.FunctionValue == nil && config.AntiCsrfFunctionOrString.StrValue == AntiCSRF_VIA_CUSTOM_HEADER { return sessmodels.GetSessionResponse{}, defaultErrors.New("Please either use VIA_TOKEN, NONE or call with doAntiCsrfCheck false") } } @@ -152,7 +152,7 @@ func getSessionHelper(config sessmodels.TypeNormalisedInput, querier supertokens requestBody := map[string]interface{}{ "accessToken": parsedAccessToken.RawTokenString, "doAntiCsrfCheck": doAntiCsrfCheck, - "enableAntiCsrf": config.AntiCsrf == AntiCSRF_VIA_TOKEN, + "enableAntiCsrf": config.AntiCsrfFunctionOrString.StrValue == AntiCSRF_VIA_TOKEN, "checkDatabase": alwaysCheckCore, } if antiCsrfToken != nil { @@ -227,13 +227,13 @@ func getSessionInformationHelper(querier supertokens.Querier, sessionHandle stri func refreshSessionHelper(config sessmodels.TypeNormalisedInput, querier supertokens.Querier, refreshToken string, antiCsrfToken *string, disableAntiCsrf bool, userContext supertokens.UserContext) (sessmodels.CreateOrRefreshAPIResponse, error) { requestBody := map[string]interface{}{ "refreshToken": refreshToken, - "enableAntiCsrf": !disableAntiCsrf && config.AntiCsrf == AntiCSRF_VIA_TOKEN, + "enableAntiCsrf": !disableAntiCsrf && config.AntiCsrfFunctionOrString.StrValue == AntiCSRF_VIA_TOKEN, } if antiCsrfToken != nil { requestBody["antiCsrfToken"] = *antiCsrfToken } - if config.AntiCsrf == AntiCSRF_VIA_CUSTOM_HEADER && !disableAntiCsrf { + if config.AntiCsrfFunctionOrString.FunctionValue == nil && config.AntiCsrfFunctionOrString.StrValue == AntiCSRF_VIA_CUSTOM_HEADER && !disableAntiCsrf { return sessmodels.CreateOrRefreshAPIResponse{}, defaultErrors.New("Please either use VIA_TOKEN, NONE or call with doAntiCsrfCheck false") } diff --git a/recipe/session/sessionRequestFunctions.go b/recipe/session/sessionRequestFunctions.go index 65db1a8b..c542a804 100644 --- a/recipe/session/sessionRequestFunctions.go +++ b/recipe/session/sessionRequestFunctions.go @@ -69,16 +69,27 @@ func CreateNewSessionInRequest(req *http.Request, res http.ResponseWriter, tenan if err != nil { return nil, err } - isTopLevelWebsiteDomainIPAddress, err := supertokens.IsAnIPAddress(appInfo.TopLevelWebsiteDomain) + + topLevelWebsiteDomain, err := appInfo.GetTopLevelWebsiteDomain(req, userContext) + if err != nil { + return nil, err + } + + isTopLevelWebsiteDomainIPAddress, err := supertokens.IsAnIPAddress(topLevelWebsiteDomain) + if err != nil { + return nil, err + } + + cookieSameSite, err := config.GetCookieSameSite(req, userContext) if err != nil { return nil, err } if outputTokenTransferMethod == sessmodels.CookieTransferMethod && - config.CookieSameSite == "none" && + cookieSameSite == "none" && !config.CookieSecure && !((appInfo.TopLevelAPIDomain == "localhost" || isTopLevelAPIDomainIPAddress) && - (appInfo.TopLevelWebsiteDomain == "localhost" || isTopLevelWebsiteDomainIPAddress)) { + (topLevelWebsiteDomain == "localhost" || isTopLevelWebsiteDomainIPAddress)) { // We can allow insecure cookie when both website & API domain are localhost or an IP // When either of them is a different domain, API domain needs to have https and a secure cookie to work return nil, defaultErrors.New("Since your API and website domain are different, for sessions to work, please use https on your apiDomain and dont set cookieSecure to false.") @@ -101,18 +112,18 @@ func CreateNewSessionInRequest(req *http.Request, res http.ResponseWriter, tenan return nil, err } if token != nil { - ClearSession(config, res, tokenTransferMethod) + ClearSession(config, res, tokenTransferMethod, req, userContext) } } } supertokens.LogDebugMessage("createNewSession: Cleared old tokens") - sessionResponse.AttachToRequestResponse(sessmodels.RequestResponseInfo{ + sessionResponse.AttachToRequestResponseWithContext(sessmodels.RequestResponseInfo{ Res: res, Req: req, TokenTransferMethod: outputTokenTransferMethod, - }) + }, userContext) supertokens.LogDebugMessage("createNewSession: Attached new tokens to res") return sessionResponse, nil @@ -192,8 +203,17 @@ func GetSessionFromRequest(req *http.Request, res http.ResponseWriter, config se doAntiCsrfCheck = &False } - if *doAntiCsrfCheck && config.AntiCsrf == AntiCSRF_VIA_CUSTOM_HEADER { - if config.AntiCsrf == AntiCSRF_VIA_CUSTOM_HEADER { + antiCsrf := config.AntiCsrfFunctionOrString.StrValue + if antiCsrf == "" { + antiCsrfTemp, err := config.AntiCsrfFunctionOrString.FunctionValue(req, userContext) + if err != nil { + return nil, err + } + antiCsrf = antiCsrfTemp + } + + if *doAntiCsrfCheck && antiCsrf == AntiCSRF_VIA_CUSTOM_HEADER { + if antiCsrf == AntiCSRF_VIA_CUSTOM_HEADER { if GetRidFromHeader(req) == nil { supertokens.LogDebugMessage("getSession: Returning TRY_REFRESH_TOKEN because custom header (rid) was not passed") return nil, errors.TryRefreshTokenError{ @@ -268,11 +288,11 @@ func GetSessionFromRequest(req *http.Request, res http.ResponseWriter, config se transferMethod = allowedTokenTransferMethod } - err = (*sessionResult).AttachToRequestResponse(sessmodels.RequestResponseInfo{ + err = (*sessionResult).AttachToRequestResponseWithContext(sessmodels.RequestResponseInfo{ Res: res, Req: req, TokenTransferMethod: transferMethod, - }) + }, userContext) if err != nil { return nil, err @@ -316,7 +336,7 @@ func RefreshSessionInRequest(req *http.Request, res http.ResponseWriter, config } else { if GetCookieValue(req, legacyIdRefreshTokenCookieName) != nil { supertokens.LogDebugMessage("refreshSession: cleared legacy id refresh token because refresh token was not found") - setCookie(config, res, legacyIdRefreshTokenCookieName, "", 0, "accessTokenPath") + setCookie(config, res, legacyIdRefreshTokenCookieName, "", 0, "accessTokenPath", req, userContext) } supertokens.LogDebugMessage("refreshSession: UNAUTHORISED because refresh token in request is undefined") @@ -329,7 +349,16 @@ func RefreshSessionInRequest(req *http.Request, res http.ResponseWriter, config antiCsrfToken := GetAntiCsrfTokenFromHeaders(req) disableAntiCSRF := requestTokenTransferMethod == sessmodels.HeaderTransferMethod - if config.AntiCsrf == AntiCSRF_VIA_CUSTOM_HEADER && !disableAntiCSRF { + antiCsrf := config.AntiCsrfFunctionOrString.StrValue + if antiCsrf == "" { + antiCsrfTemp, err := config.AntiCsrfFunctionOrString.FunctionValue(req, userContext) + if err != nil { + return nil, err + } + antiCsrf = antiCsrfTemp + } + + if antiCsrf == AntiCSRF_VIA_CUSTOM_HEADER && !disableAntiCSRF { ridFromHeader := GetRidFromHeader(req) if ridFromHeader == nil { @@ -355,7 +384,7 @@ func RefreshSessionInRequest(req *http.Request, res http.ResponseWriter, config if (isTokenTheftDetectedErr) || (isUnauthorisedErr && unauthorisedErr.ClearTokens != nil && *unauthorisedErr.ClearTokens) { if GetCookieValue(req, legacyIdRefreshTokenCookieName) != nil { supertokens.LogDebugMessage("refreshSession: cleared legacy id refresh token because refresh is clearing other tokens") - setCookie(config, res, legacyIdRefreshTokenCookieName, "", 0, "accessTokenPath") + setCookie(config, res, legacyIdRefreshTokenCookieName, "", 0, "accessTokenPath", req, userContext) } } @@ -370,21 +399,21 @@ func RefreshSessionInRequest(req *http.Request, res http.ResponseWriter, config for _, tokenTransferMethod := range AvailableTokenTransferMethods { if tokenTransferMethod != requestTokenTransferMethod && refreshTokens[tokenTransferMethod] != nil { - ClearSession(config, res, tokenTransferMethod) + ClearSession(config, res, tokenTransferMethod, req, userContext) } } - (*result).AttachToRequestResponse(sessmodels.RequestResponseInfo{ + (*result).AttachToRequestResponseWithContext(sessmodels.RequestResponseInfo{ Res: res, Req: req, TokenTransferMethod: requestTokenTransferMethod, - }) + }, userContext) supertokens.LogDebugMessage("refreshSession: Success!") if GetCookieValue(req, legacyIdRefreshTokenCookieName) != nil { supertokens.LogDebugMessage("refreshSession: cleared legacy id refresh token after successful refresh") - setCookie(config, res, legacyIdRefreshTokenCookieName, "", 0, "accessTokenPath") + setCookie(config, res, legacyIdRefreshTokenCookieName, "", 0, "accessTokenPath", req, userContext) } return result, nil diff --git a/recipe/session/session_test.go b/recipe/session/session_test.go index 1ee7763a..62fa5012 100644 --- a/recipe/session/session_test.go +++ b/recipe/session/session_test.go @@ -1915,6 +1915,7 @@ func TestSessionVerificationOfJWTBasedOnSessionPayloadWithCheckDatabase(t *testi delete(payload, "iat") delete(payload, "exp") payload["tId"] = "public" + payload["rsub"] = session.GetUserID() currentTimeInSeconds := time.Now() jwtExpiry := uint64((currentTimeInSeconds.Add(10 * time.Second)).Unix()) diff --git a/recipe/session/sessmodels/models.go b/recipe/session/sessmodels/models.go index 8c9dafbc..75e8b630 100644 --- a/recipe/session/sessmodels/models.go +++ b/recipe/session/sessmodels/models.go @@ -127,11 +127,11 @@ type ErrorHandlers struct { type TypeNormalisedInput struct { RefreshTokenPath supertokens.NormalisedURLPath CookieDomain *string - CookieSameSite string + GetCookieSameSite func(request *http.Request, userContext supertokens.UserContext) (string, error) CookieSecure bool SessionExpiredStatusCode int InvalidClaimStatusCode int - AntiCsrf string + AntiCsrfFunctionOrString AntiCsrfFunctionOrString Override OverrideStruct ErrorHandlers NormalisedErrorHandlers GetTokenTransferMethod func(req *http.Request, forCreateNewSession bool, userContext supertokens.UserContext) TokenTransferMethod @@ -139,6 +139,11 @@ type TypeNormalisedInput struct { UseDynamicAccessTokenSigningKey bool } +type AntiCsrfFunctionOrString struct { + StrValue string + FunctionValue func(request *http.Request, userContext supertokens.UserContext) (string, error) +} + type JWTNormalisedConfig struct { Issuer *string Enable bool @@ -210,11 +215,12 @@ type TypeSessionContainer struct { MergeIntoAccessTokenPayloadWithContext func(accessTokenPayloadUpdate map[string]interface{}, userContext supertokens.UserContext) error - AssertClaimsWithContext func(claimValidators []claims.SessionClaimValidator, userContext supertokens.UserContext) error - FetchAndSetClaimWithContext func(claim *claims.TypeSessionClaim, userContext supertokens.UserContext) error - SetClaimValueWithContext func(claim *claims.TypeSessionClaim, value interface{}, userContext supertokens.UserContext) error - GetClaimValueWithContext func(claim *claims.TypeSessionClaim, userContext supertokens.UserContext) interface{} - RemoveClaimWithContext func(claim *claims.TypeSessionClaim, userContext supertokens.UserContext) error + AssertClaimsWithContext func(claimValidators []claims.SessionClaimValidator, userContext supertokens.UserContext) error + FetchAndSetClaimWithContext func(claim *claims.TypeSessionClaim, userContext supertokens.UserContext) error + SetClaimValueWithContext func(claim *claims.TypeSessionClaim, value interface{}, userContext supertokens.UserContext) error + GetClaimValueWithContext func(claim *claims.TypeSessionClaim, userContext supertokens.UserContext) interface{} + RemoveClaimWithContext func(claim *claims.TypeSessionClaim, userContext supertokens.UserContext) error + AttachToRequestResponseWithContext func(info RequestResponseInfo, userContext supertokens.UserContext) error MergeIntoAccessTokenPayload func(accessTokenPayloadUpdate map[string]interface{}) error diff --git a/recipe/session/utils.go b/recipe/session/utils.go index 97f6d7dd..7dc287be 100644 --- a/recipe/session/utils.go +++ b/recipe/session/utils.go @@ -57,25 +57,41 @@ func ValidateAndNormaliseUserInput(appInfo supertokens.NormalisedAppinfo, config } } - apiDomainScheme, err := GetURLScheme(appInfo.APIDomain.GetAsStringDangerous()) - if err != nil { - return sessmodels.TypeNormalisedInput{}, err - } - websiteDomainScheme, err := GetURLScheme(appInfo.WebsiteDomain.GetAsStringDangerous()) - if err != nil { - return sessmodels.TypeNormalisedInput{}, err + if config != nil && config.CookieSameSite != nil { + // we have this block just to check if the user input is correct + _, err = normaliseSameSiteOrThrowError(*config.CookieSameSite) + if err != nil { + return sessmodels.TypeNormalisedInput{}, err + } } - cookieSameSite := CookieSameSite_LAX - if apiDomainScheme != websiteDomainScheme || appInfo.TopLevelAPIDomain != appInfo.TopLevelWebsiteDomain { - cookieSameSite = CookieSameSite_NONE - } + cookieSameSite := func(request *http.Request, userContext supertokens.UserContext) (string, error) { + if config != nil && config.CookieSameSite != nil { + return normaliseSameSiteOrThrowError(*config.CookieSameSite) + } + origin, err := appInfo.GetOrigin(request, userContext) + if err != nil { + return "", err + } + protocolOfWebsiteDomain, err := GetURLScheme(origin.GetAsStringDangerous()) + if err != nil { + return "", err + } - if config != nil && config.CookieSameSite != nil { - cookieSameSite, err = normaliseSameSiteOrThrowError(*config.CookieSameSite) + protocolOfAPIDomain, err := GetURLScheme(appInfo.APIDomain.GetAsStringDangerous()) if err != nil { - return sessmodels.TypeNormalisedInput{}, err + return "", err + } + + topLevelWebsiteDomain, err := appInfo.GetTopLevelWebsiteDomain(request, userContext) + if err != nil { + return "", err } + + if protocolOfAPIDomain != protocolOfWebsiteDomain || appInfo.TopLevelAPIDomain != topLevelWebsiteDomain { + return CookieSameSite_NONE, nil + } + return CookieSameSite_LAX, nil } cookieSecure := false @@ -99,21 +115,32 @@ func ValidateAndNormaliseUserInput(appInfo supertokens.NormalisedAppinfo, config return sessmodels.TypeNormalisedInput{}, errors.New("SessionExpiredStatusCode and InvalidClaimStatusCode cannot have the same value") } + antiCsrfFunctionOrString := sessmodels.AntiCsrfFunctionOrString{ + FunctionValue: func(request *http.Request, userContext supertokens.UserContext) (string, error) { + sameSite, err := cookieSameSite(request, userContext) + if err != nil { + return "", err + } + if sameSite == CookieSameSite_NONE { + return AntiCSRF_VIA_CUSTOM_HEADER, nil + } + return AntiCSRF_NONE, nil + }, + } if config != nil && config.AntiCsrf != nil { if *config.AntiCsrf != AntiCSRF_NONE && *config.AntiCsrf != AntiCSRF_VIA_CUSTOM_HEADER && *config.AntiCsrf != AntiCSRF_VIA_TOKEN { return sessmodels.TypeNormalisedInput{}, errors.New("antiCsrf config must be one of 'NONE' or 'VIA_CUSTOM_HEADER' or 'VIA_TOKEN'") } + antiCsrfFunctionOrString = sessmodels.AntiCsrfFunctionOrString{ + StrValue: *config.AntiCsrf, + } } - antiCsrf := AntiCSRF_NONE - if config == nil || config.AntiCsrf == nil { - if cookieSameSite == CookieSameSite_NONE { - antiCsrf = AntiCSRF_VIA_CUSTOM_HEADER - } else { - antiCsrf = AntiCSRF_NONE - } - } else { - antiCsrf = *config.AntiCsrf + if antiCsrfFunctionOrString.FunctionValue != nil && antiCsrfFunctionOrString.StrValue != "" { + return sessmodels.TypeNormalisedInput{}, errors.New("should never come here") + } + if antiCsrfFunctionOrString.FunctionValue == nil && antiCsrfFunctionOrString.StrValue == "" { + return sessmodels.TypeNormalisedInput{}, errors.New("should never come here") } errorHandlers := sessmodels.NormalisedErrorHandlers{ @@ -181,11 +208,11 @@ func ValidateAndNormaliseUserInput(appInfo supertokens.NormalisedAppinfo, config typeNormalisedInput := sessmodels.TypeNormalisedInput{ RefreshTokenPath: appInfo.APIBasePath.AppendPath(refreshAPIPath), CookieDomain: cookieDomain, - CookieSameSite: cookieSameSite, + GetCookieSameSite: cookieSameSite, CookieSecure: cookieSecure, SessionExpiredStatusCode: sessionExpiredStatusCode, InvalidClaimStatusCode: invalidClaimStatusCode, - AntiCsrf: antiCsrf, + AntiCsrfFunctionOrString: antiCsrfFunctionOrString, ExposeAccessTokenToFrontendInCookieBasedAuth: config.ExposeAccessTokenToFrontendInCookieBasedAuth, UseDynamicAccessTokenSigningKey: useDynamicSigningKey, ErrorHandlers: errorHandlers, @@ -268,20 +295,20 @@ func GetCurrTimeInMS() uint64 { return uint64(time.Now().UnixNano() / 1000000) } -func SetAccessTokenInResponse(config sessmodels.TypeNormalisedInput, res http.ResponseWriter, accessToken string, frontToken string, tokenTransferMethod sessmodels.TokenTransferMethod) error { +func SetAccessTokenInResponse(config sessmodels.TypeNormalisedInput, res http.ResponseWriter, accessToken string, frontToken string, tokenTransferMethod sessmodels.TokenTransferMethod, request *http.Request, userContext supertokens.UserContext) error { setFrontTokenInHeaders(res, frontToken) // We set the expiration to 100 years, because we can't really access the expiration of the refresh token everywhere we are setting it. // This should be safe to do, since this is only the validity of the cookie (set here or on the frontend) but we check the expiration of the JWT anyway. // Even if the token is expired the presence of the token indicates that the user could have a valid refresh // Setting them to infinity would require special case handling on the frontend and just adding 100 years seems enough. - setToken(config, res, sessmodels.AccessToken, accessToken, GetCurrTimeInMS()+uint64(accessTokenCookiesExpiryDurationMillis), tokenTransferMethod) + setToken(config, res, sessmodels.AccessToken, accessToken, GetCurrTimeInMS()+uint64(accessTokenCookiesExpiryDurationMillis), tokenTransferMethod, request, userContext) if config.ExposeAccessTokenToFrontendInCookieBasedAuth && tokenTransferMethod == sessmodels.CookieTransferMethod { // We set the expiration to 100 years, because we can't really access the expiration of the refresh token everywhere we are setting it. // This should be safe to do, since this is only the validity of the cookie (set here or on the frontend) but we check the expiration of the JWT anyway. // Even if the token is expired the presence of the token indicates that the user could have a valid refresh // Setting them to infinity would require special case handling on the frontend and just adding 100 years seems enough. - setToken(config, res, sessmodels.AccessToken, accessToken, GetCurrTimeInMS()+uint64(accessTokenCookiesExpiryDurationMillis), sessmodels.HeaderTransferMethod) + setToken(config, res, sessmodels.AccessToken, accessToken, GetCurrTimeInMS()+uint64(accessTokenCookiesExpiryDurationMillis), sessmodels.HeaderTransferMethod, request, userContext) } return nil } diff --git a/recipe/thirdparty/recipe.go b/recipe/thirdparty/recipe.go index d564ae44..82db5460 100644 --- a/recipe/thirdparty/recipe.go +++ b/recipe/thirdparty/recipe.go @@ -155,7 +155,7 @@ func (r *Recipe) getAllCORSHeaders() []string { return []string{} } -func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter) (bool, error) { +func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter, userContext supertokens.UserContext) (bool, error) { if errors.As(err, &tperrors.ClientTypeNotFoundError{}) { supertokens.SendNon200ResponseWithMessage(res, err.Error(), 400) return true, nil diff --git a/recipe/thirdpartyemailpassword/main.go b/recipe/thirdpartyemailpassword/main.go index 52abbb0f..c297bf89 100644 --- a/recipe/thirdpartyemailpassword/main.go +++ b/recipe/thirdpartyemailpassword/main.go @@ -173,14 +173,22 @@ func CreateResetPasswordLink(tenantId string, userID string, userContext ...supe return epmodels.CreateResetPasswordLinkResponse{}, err } + link, err := api.GetPasswordResetLink( + instance.RecipeModule.GetAppInfo(), + instance.RecipeModule.GetRecipeID(), + tokenResponse.OK.Token, + tenantId, + supertokens.GetRequestFromUserContext(userContext[0]), + userContext[0], + ) + + if err != nil { + return epmodels.CreateResetPasswordLinkResponse{}, err + } + return epmodels.CreateResetPasswordLinkResponse{ OK: &struct{ Link string }{ - Link: api.GetPasswordResetLink( - instance.RecipeModule.GetAppInfo(), - instance.RecipeModule.GetRecipeID(), - tokenResponse.OK.Token, - tenantId, - ), + Link: link, }, }, nil } diff --git a/recipe/thirdpartyemailpassword/recipe.go b/recipe/thirdpartyemailpassword/recipe.go index 3743dd1e..382f385b 100644 --- a/recipe/thirdpartyemailpassword/recipe.go +++ b/recipe/thirdpartyemailpassword/recipe.go @@ -201,13 +201,13 @@ func (r *Recipe) getAllCORSHeaders() []string { return corsHeaders } -func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter) (bool, error) { - handleError, err := r.emailPasswordRecipe.RecipeModule.HandleError(err, req, res) +func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter, userContext supertokens.UserContext) (bool, error) { + handleError, err := r.emailPasswordRecipe.RecipeModule.HandleError(err, req, res, userContext) if err != nil || handleError { return handleError, err } if r.thirdPartyRecipe != nil { - handleError, err = r.thirdPartyRecipe.RecipeModule.HandleError(err, req, res) + handleError, err = r.thirdPartyRecipe.RecipeModule.HandleError(err, req, res, userContext) if err != nil || handleError { return handleError, err } diff --git a/recipe/thirdpartypasswordless/recipe.go b/recipe/thirdpartypasswordless/recipe.go index 8b03c1fa..e4566516 100644 --- a/recipe/thirdpartypasswordless/recipe.go +++ b/recipe/thirdpartypasswordless/recipe.go @@ -204,13 +204,13 @@ func (r *Recipe) getAllCORSHeaders() []string { return corsHeaders } -func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter) (bool, error) { - handleError, err := r.passwordlessRecipe.RecipeModule.HandleError(err, req, res) +func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter, userContext supertokens.UserContext) (bool, error) { + handleError, err := r.passwordlessRecipe.RecipeModule.HandleError(err, req, res, userContext) if err != nil || handleError { return handleError, err } if r.thirdPartyRecipe != nil { - handleError, err = r.thirdPartyRecipe.RecipeModule.HandleError(err, req, res) + handleError, err = r.thirdPartyRecipe.RecipeModule.HandleError(err, req, res, userContext) if err != nil || handleError { return handleError, err } diff --git a/recipe/usermetadata/recipe.go b/recipe/usermetadata/recipe.go index 2fb0cec2..4923e017 100644 --- a/recipe/usermetadata/recipe.go +++ b/recipe/usermetadata/recipe.go @@ -86,7 +86,7 @@ func (r *Recipe) getAllCORSHeaders() []string { return []string{} } -func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter) (bool, error) { +func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter, userContext supertokens.UserContext) (bool, error) { return false, nil } diff --git a/recipe/userroles/recipe.go b/recipe/userroles/recipe.go index c49af9ba..2b961a6f 100644 --- a/recipe/userroles/recipe.go +++ b/recipe/userroles/recipe.go @@ -106,7 +106,7 @@ func (r *Recipe) getAllCORSHeaders() []string { return []string{} } -func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter) (bool, error) { +func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter, userContext supertokens.UserContext) (bool, error) { return false, nil } diff --git a/supertokens/constants.go b/supertokens/constants.go index 88141066..2f4f72de 100644 --- a/supertokens/constants.go +++ b/supertokens/constants.go @@ -21,7 +21,7 @@ const ( ) // VERSION current version of the lib -const VERSION = "0.17.0" +const VERSION = "0.17.1" var ( cdiSupported = []string{"3.0"} diff --git a/supertokens/main.go b/supertokens/main.go index 338919b1..df5a8a8b 100644 --- a/supertokens/main.go +++ b/supertokens/main.go @@ -41,12 +41,15 @@ func Middleware(theirHandler http.Handler) http.Handler { return instance.middleware(theirHandler) } -func ErrorHandler(err error, req *http.Request, res http.ResponseWriter) error { +func ErrorHandler(err error, req *http.Request, res http.ResponseWriter, userContext ...UserContext) error { instance, instanceErr := GetInstanceOrThrowError() if instanceErr != nil { return instanceErr } - return instance.errorHandler(err, req, res) + if len(userContext) == 0 { + userContext = append(userContext, &map[string]interface{}{}) + } + return instance.errorHandler(err, req, res, userContext[0]) } func GetAllCORSHeaders() []string { diff --git a/supertokens/models.go b/supertokens/models.go index 2a89cd96..7f84792e 100644 --- a/supertokens/models.go +++ b/supertokens/models.go @@ -20,19 +20,21 @@ import ( ) type NormalisedAppinfo struct { - AppName string - WebsiteDomain NormalisedURLDomain - APIDomain NormalisedURLDomain - TopLevelAPIDomain string - TopLevelWebsiteDomain string - APIBasePath NormalisedURLPath - APIGatewayPath NormalisedURLPath - WebsiteBasePath NormalisedURLPath + AppName string + GetOrigin func(request *http.Request, userContext UserContext) (NormalisedURLDomain, error) + APIDomain NormalisedURLDomain + TopLevelAPIDomain string + GetTopLevelWebsiteDomain func(request *http.Request, userContext UserContext) (string, error) + APIBasePath NormalisedURLPath + APIGatewayPath NormalisedURLPath + WebsiteBasePath NormalisedURLPath } type AppInfo struct { AppName string WebsiteDomain string + Origin string + GetOrigin func(request *http.Request, userContext UserContext) (string, error) APIDomain string WebsiteBasePath *string APIBasePath *string diff --git a/supertokens/querier.go b/supertokens/querier.go index a9b6d265..599b72e3 100644 --- a/supertokens/querier.go +++ b/supertokens/querier.go @@ -132,13 +132,13 @@ func (q *Querier) SendPostRequest(path string, data map[string]interface{}, user return nil, err } - apiVerion, querierAPIVersionError := q.GetQuerierAPIVersion() + apiVersion, querierAPIVersionError := q.GetQuerierAPIVersion() if querierAPIVersionError != nil { return nil, querierAPIVersionError } req.Header.Set("content-type", "application/json; charset=utf-8") - req.Header.Set("cdi-version", apiVerion) + req.Header.Set("cdi-version", apiVersion) if QuerierAPIKey != nil { req.Header.Set("api-key", *QuerierAPIKey) } @@ -178,13 +178,13 @@ func (q *Querier) SendDeleteRequest(path string, data map[string]interface{}, pa } req.URL.RawQuery = query.Encode() - apiVerion, querierAPIVersionError := q.GetQuerierAPIVersion() + apiVersion, querierAPIVersionError := q.GetQuerierAPIVersion() if querierAPIVersionError != nil { return nil, querierAPIVersionError } req.Header.Set("content-type", "application/json; charset=utf-8") - req.Header.Set("cdi-version", apiVerion) + req.Header.Set("cdi-version", apiVersion) if QuerierAPIKey != nil { req.Header.Set("api-key", *QuerierAPIKey) } @@ -220,11 +220,11 @@ func (q *Querier) SendGetRequest(path string, params map[string]string, userCont } req.URL.RawQuery = query.Encode() - apiVerion, querierAPIVersionError := q.GetQuerierAPIVersion() + apiVersion, querierAPIVersionError := q.GetQuerierAPIVersion() if querierAPIVersionError != nil { return nil, querierAPIVersionError } - req.Header.Set("cdi-version", apiVerion) + req.Header.Set("cdi-version", apiVersion) if QuerierAPIKey != nil { req.Header.Set("api-key", *QuerierAPIKey) } @@ -261,11 +261,11 @@ func (q *Querier) SendGetRequestWithResponseHeaders(path string, params map[stri } req.URL.RawQuery = query.Encode() - apiVerion, querierAPIVersionError := q.GetQuerierAPIVersion() + apiVersion, querierAPIVersionError := q.GetQuerierAPIVersion() if querierAPIVersionError != nil { return nil, querierAPIVersionError } - req.Header.Set("cdi-version", apiVerion) + req.Header.Set("cdi-version", apiVersion) if QuerierAPIKey != nil { req.Header.Set("api-key", *QuerierAPIKey) } @@ -297,13 +297,13 @@ func (q *Querier) SendPutRequest(path string, data map[string]interface{}, userC return nil, err } - apiVerion, querierAPIVersionError := q.GetQuerierAPIVersion() + apiVersion, querierAPIVersionError := q.GetQuerierAPIVersion() if querierAPIVersionError != nil { return nil, querierAPIVersionError } req.Header.Set("content-type", "application/json; charset=utf-8") - req.Header.Set("cdi-version", apiVerion) + req.Header.Set("cdi-version", apiVersion) if QuerierAPIKey != nil { req.Header.Set("api-key", *QuerierAPIKey) } diff --git a/supertokens/recipeModule.go b/supertokens/recipeModule.go index 9a121e96..d7bebe8f 100644 --- a/supertokens/recipeModule.go +++ b/supertokens/recipeModule.go @@ -28,7 +28,7 @@ type RecipeModule struct { GetAllCORSHeaders func() []string GetAPIsHandled func() ([]APIHandled, error) ReturnAPIIdIfCanHandleRequest func(path NormalisedURLPath, method string, userContext UserContext) (*string, string, error) - HandleError func(err error, req *http.Request, res http.ResponseWriter) (bool, error) + HandleError func(err error, req *http.Request, res http.ResponseWriter, userContext UserContext) (bool, error) OnSuperTokensAPIError func(err error, req *http.Request, res http.ResponseWriter) } @@ -39,7 +39,7 @@ func MakeRecipeModule( getAllCORSHeaders func() []string, getAPIsHandled func() ([]APIHandled, error), returnAPIIdIfCanHandleRequest func(path NormalisedURLPath, method string, userContext UserContext) (*string, string, error), - handleError func(err error, req *http.Request, res http.ResponseWriter) (bool, error), + handleError func(err error, req *http.Request, res http.ResponseWriter, userContext UserContext) (bool, error), onSuperTokensAPIError func(err error, req *http.Request, res http.ResponseWriter)) RecipeModule { if handleError == nil { // Execution will come here only if there is a bug in the code diff --git a/supertokens/supertokens.go b/supertokens/supertokens.go index a846e82f..b3bc166e 100644 --- a/supertokens/supertokens.go +++ b/supertokens/supertokens.go @@ -58,7 +58,20 @@ func supertokensInit(config TypeInput) error { LogDebugMessage("Started SuperTokens with debug logging (supertokens.Init called)") - appInfoJsonString, _ := json.Marshal(config.AppInfo) + // we do this below because we cannot marshal a function. + jsonableStruct := map[string]interface{}{ + "AppName": config.AppInfo.AppName, + "Origin": config.AppInfo.Origin, + "WebsiteDomain": config.AppInfo.WebsiteDomain, + "APIDomain": config.AppInfo.APIDomain, + "WebsiteBasePath": config.AppInfo.WebsiteBasePath, + "APIBasePath": config.AppInfo.APIBasePath, + "APIGatewayPath": config.AppInfo.APIGatewayPath, + } + if config.AppInfo.GetOrigin != nil { + jsonableStruct["Origin"] = "function" + } + appInfoJsonString, _ := json.Marshal(jsonableStruct) LogDebugMessage("AppInfo: " + string(appInfoJsonString)) var err error @@ -147,7 +160,7 @@ func (s *superTokens) middleware(theirHandler http.Handler) http.Handler { userContext := MakeDefaultUserContextFromAPI(r) reqURL, err := NewNormalisedURLPath(r.URL.Path) if err != nil { - err = s.errorHandler(err, r, dw) + err = s.errorHandler(err, r, dw, userContext) if err != nil && !dw.IsDone() { s.OnSuperTokensAPIError(err, r, dw) } @@ -187,7 +200,7 @@ func (s *superTokens) middleware(theirHandler http.Handler) http.Handler { id, tenantId, err := matchedRecipe.ReturnAPIIdIfCanHandleRequest(path, method, userContext) if err != nil { - err = s.errorHandler(err, r, dw) + err = s.errorHandler(err, r, dw, userContext) if err != nil && !dw.IsDone() { s.OnSuperTokensAPIError(err, r, dw) } @@ -204,7 +217,7 @@ func (s *superTokens) middleware(theirHandler http.Handler) http.Handler { tenantId, err = GetTenantIdFuncFromUsingMultitenancyRecipe(tenantId, userContext) if err != nil { - err = s.errorHandler(err, r, dw) + err = s.errorHandler(err, r, dw, userContext) if err != nil && !dw.IsDone() { s.OnSuperTokensAPIError(err, r, dw) } @@ -213,7 +226,7 @@ func (s *superTokens) middleware(theirHandler http.Handler) http.Handler { apiErr := matchedRecipe.HandleAPIRequest(*id, tenantId, r, dw, theirHandler.ServeHTTP, path, method, userContext) if apiErr != nil { - apiErr = s.errorHandler(apiErr, r, dw) + apiErr = s.errorHandler(apiErr, r, dw, userContext) if apiErr != nil && !dw.IsDone() { s.OnSuperTokensAPIError(apiErr, r, dw) } @@ -225,7 +238,7 @@ func (s *superTokens) middleware(theirHandler http.Handler) http.Handler { id, tenantId, err := recipeModule.ReturnAPIIdIfCanHandleRequest(path, method, userContext) LogDebugMessage("middleware: Checking recipe ID for match: " + recipeModule.GetRecipeID()) if err != nil { - err = s.errorHandler(err, r, dw) + err = s.errorHandler(err, r, dw, userContext) if err != nil && !dw.IsDone() { s.OnSuperTokensAPIError(err, r, dw) } @@ -236,7 +249,7 @@ func (s *superTokens) middleware(theirHandler http.Handler) http.Handler { LogDebugMessage("middleware: Request being handled by recipe. ID is: " + *id) err := recipeModule.HandleAPIRequest(*id, tenantId, r, dw, theirHandler.ServeHTTP, path, method, userContext) if err != nil { - err = s.errorHandler(err, r, dw) + err = s.errorHandler(err, r, dw, userContext) if err != nil && !dw.IsDone() { s.OnSuperTokensAPIError(err, r, dw) } @@ -268,7 +281,7 @@ func (s *superTokens) getAllCORSHeaders() []string { return headers } -func (s *superTokens) errorHandler(originalError error, req *http.Request, res http.ResponseWriter) error { +func (s *superTokens) errorHandler(originalError error, req *http.Request, res http.ResponseWriter, userContext UserContext) error { LogDebugMessage("errorHandler: Started") if errors.As(originalError, &BadInputError{}) { LogDebugMessage("errorHandler: Sending 400 status code response") @@ -286,7 +299,7 @@ func (s *superTokens) errorHandler(originalError error, req *http.Request, res h LogDebugMessage("errorHandler: Checking recipe for match: " + recipe.recipeID) if recipe.HandleError != nil { LogDebugMessage("errorHandler: Matched with recipeId: " + recipe.recipeID) - handled, err := recipe.HandleError(originalError, req, res) + handled, err := recipe.HandleError(originalError, req, res, userContext) if err != nil { return err } diff --git a/supertokens/utils.go b/supertokens/utils.go index f0e8c027..059b6150 100644 --- a/supertokens/utils.go +++ b/supertokens/utils.go @@ -47,9 +47,6 @@ func NormaliseInputAppInfoOrThrowError(appInfo AppInfo) (NormalisedAppinfo, erro if appInfo.AppName == "" { return NormalisedAppinfo{}, errors.New("Please provide your appName inside the appInfo object when calling supertokens.init") } - if appInfo.WebsiteDomain == "" { - return NormalisedAppinfo{}, errors.New("Please provide your websiteDomain inside the appInfo object when calling supertokens.init") - } apiGatewayPath, err := NewNormalisedURLPath("") if err != nil { return NormalisedAppinfo{}, err @@ -60,10 +57,28 @@ func NormaliseInputAppInfoOrThrowError(appInfo AppInfo) (NormalisedAppinfo, erro return NormalisedAppinfo{}, err } } - websiteDomain, err := NewNormalisedURLDomain(appInfo.WebsiteDomain) - if err != nil { - return NormalisedAppinfo{}, err + + if appInfo.Origin == "" && appInfo.WebsiteDomain == "" && appInfo.GetOrigin == nil { + return NormalisedAppinfo{}, errors.New("Please provide either Origin, GetOrigin or WebsiteDomain inside the appInfo object when calling supertokens.init") } + + websiteDomainFunction := func(request *http.Request, userContext UserContext) (NormalisedURLDomain, error) { + origin := appInfo.Origin + if origin == "" { + origin = appInfo.WebsiteDomain + } + + if appInfo.GetOrigin != nil { + originResult, err := appInfo.GetOrigin(request, userContext) + if err != nil { + return NormalisedURLDomain{}, err + } + origin = originResult + } + + return NewNormalisedURLDomain(origin) + } + apiDomain, err := NewNormalisedURLDomain(appInfo.APIDomain) if err != nil { return NormalisedAppinfo{}, err @@ -73,9 +88,13 @@ func NormaliseInputAppInfoOrThrowError(appInfo AppInfo) (NormalisedAppinfo, erro if err != nil { return NormalisedAppinfo{}, err } - topLevelWebsiteDomain, err := GetTopLevelDomainForSameSiteResolution(websiteDomain.GetAsStringDangerous()) - if err != nil { - return NormalisedAppinfo{}, err + + getTopLevelWebsiteDomain := func(request *http.Request, userContext UserContext) (string, error) { + origin, err := websiteDomainFunction(request, userContext) + if err != nil { + return "", err + } + return GetTopLevelDomainForSameSiteResolution(origin.GetAsStringDangerous()) } APIBasePathStr := "/auth" @@ -97,14 +116,14 @@ func NormaliseInputAppInfoOrThrowError(appInfo AppInfo) (NormalisedAppinfo, erro return NormalisedAppinfo{}, err } return NormalisedAppinfo{ - AppName: appInfo.AppName, - APIGatewayPath: apiGatewayPath, - WebsiteDomain: websiteDomain, - APIDomain: apiDomain, - APIBasePath: apiBasePath, - TopLevelAPIDomain: topLevelAPIDomain, - TopLevelWebsiteDomain: topLevelWebsiteDomain, - WebsiteBasePath: websiteBasePath, + AppName: appInfo.AppName, + APIGatewayPath: apiGatewayPath, + GetOrigin: websiteDomainFunction, + APIDomain: apiDomain, + APIBasePath: apiBasePath, + TopLevelAPIDomain: topLevelAPIDomain, + GetTopLevelWebsiteDomain: getTopLevelWebsiteDomain, + WebsiteBasePath: websiteBasePath, }, nil }