diff --git a/CHANGELOG.md b/CHANGELOG.md index 6b954343..6c14a44a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,44 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [unreleased] +## [0.19.0] - 2024-05-01 + +- Added `OlderCookieDomain` config option in the session recipe. This will allow users to clear cookies from the older domain when the `CookieDomain` is changed. +- If `VerifySession` detects multiple access tokens in the request, it will return a 401 error, prompting a refresh, even if one of the tokens is valid. +- `RefreshPOST` (`/auth/session/refresh` by default) API changes: + - now returns 500 error if multiple access tokens are present in the request and `config.OlderCookieDomain` is not set. + - now clears the access token cookie if it was called without a refresh token (if an access token cookie exists and if using cookie-based sessions). + - now clears cookies from the old domain if `OlderCookieDomain` is specified and multiple refresh/access token cookies exist, without updating the front-token or any of the tokens. + - now a 200 response may not include new session tokens. +- Fixed a bug in the `normaliseSessionScopeOrThrowError` util function that caused it to remove leading dots from the scope string. + +### Rationale + +This update addresses an edge case where changing the `CookieDomain` config on the server can lead to session integrity issues. For instance, if the API server URL is 'api.example.com' with a cookie domain of '.example.com', and the server updates the cookie domain to 'api.example.com', the client may retain cookies with both '.example.com' and 'api.example.com' domains, resulting in multiple sets of session token cookies existing. + +Previously, verifySession would select one of the access tokens from the incoming request. If it chose the older cookie, it would return a 401 status code, prompting a refresh request. However, the `RefreshPOST` API would then set new session token cookies with the updated `CookieDomain`, but older cookies will persist, leading to repeated 401 errors and refresh loops. + +With this update, verifySession will return a 401 error if it detects multiple access tokens in the request, prompting a refresh request. The `RefreshPOST` API will clear cookies from the old domain if `OlderCookieDomain` is specified in the configuration, then return a 200 status. If `OlderCookieDomain` is not configured, the `RefreshPOST` API will return a 500 error with a message instructing to set `OlderCookieDomain`. + + +**Example:** + +- `APIDomain`: 'api.example.com' +- `CookieDomain`: 'api.example.com' + +**Flow:** + +1. After authentication, the frontend has cookies set with `domain=api.example.com`, but the access token has expired. +2. The server updates `CookieDomain` to `.example.com`. +3. An API call requiring session with an expired access token (cookie with `domain=api.example.com`) results in a 401 response. +4. The frontend attempts to refresh the session, generating a new access token saved with `domain=.example.com`. +5. The original API call is retried, but because it sends both the old and new cookies, it again results in a 401 response. +6. The frontend tries to refresh the session with multiple access tokens: + - If `OlderCookieDomain` is not set, the refresh fails with a 500 error. + - The user remains stuck until they clear cookies manually or `OlderCookieDomain` is set. + - If `OlderCookieDomain` is set, the refresh clears the older cookie, returning a 200 response. + - The frontend retries the original API call, sending only the new cookie (`domain=.example.com`), resulting in a successful request. + ## [0.18.0] - 2024-04-30 ### Changes diff --git a/recipe/session/cookieAndHeaders.go b/recipe/session/cookieAndHeaders.go index 056f282b..34fb2ea3 100644 --- a/recipe/session/cookieAndHeaders.go +++ b/recipe/session/cookieAndHeaders.go @@ -26,6 +26,8 @@ import ( "strings" "time" + sessionError "github.com/supertokens/supertokens-golang/recipe/session/errors" + "github.com/supertokens/supertokens-golang/supertokens" "github.com/supertokens/supertokens-golang/recipe/session/sessmodels" @@ -316,3 +318,68 @@ func getCookieName(cookie string) string { } return kv[0] } + +// ClearSessionCookiesFromOlderCookieDomain addresses an edge case where changing the cookieDomain config on the server can +// lead to session integrity issues. For instance, if the API server URL is 'api.example.com' +// with a cookie domain of '.example.com', and the server updates the cookie domain to 'api.example.com', +// the client may retain cookies with both '.example.com' and 'api.example.com' domains. +// +// Consequently, if the server chooses the older cookie, session invalidation occurs, potentially +// resulting in an infinite refresh loop. To fix this, users are asked to specify "OlderCookieDomain" in +// the config. +// +// This function checks for multiple cookies with the same name and clears the cookies for the older domain. +func ClearSessionCookiesFromOlderCookieDomain(req *http.Request, res http.ResponseWriter, config sessmodels.TypeNormalisedInput, userContext supertokens.UserContext) error { + allowedTransferMethod := config.GetTokenTransferMethod(req, false, userContext) + + // If the transfer method is 'header', there's no need to clear cookies immediately, even if there are multiple in the request. + if allowedTransferMethod == sessmodels.HeaderTransferMethod { + return nil + } + + didClearCookies := false + + tokenTypes := []sessmodels.TokenType{sessmodels.AccessToken, sessmodels.RefreshToken} + for _, token := range tokenTypes { + if hasMultipleCookiesForTokenType(req, token) { + // If a request has multiple session cookies and 'olderCookieDomain' is + // unset, we can't identify the correct cookie for refreshing the session. + // Using the wrong cookie can cause an infinite refresh loop. To avoid this, + // we throw a 500 error asking the user to set 'olderCookieDomain'. + if config.OlderCookieDomain == nil { + return errors.New(`The request contains multiple session cookies. This may happen if you've changed the 'cookieDomain' value in your configuration. To clear tokens from the previous domain, set 'olderCookieDomain' in your config.`) + } + + supertokens.LogDebugMessage(fmt.Sprint("ClearSessionCookiesFromOlderCookieDomain: Clearing duplicate ", token, " cookie with domain ", config.OlderCookieDomain)) + config.CookieDomain = config.OlderCookieDomain + setToken(config, res, token, "", 0, sessmodels.CookieTransferMethod, req, userContext) + + didClearCookies = true + } + } + + if didClearCookies { + return sessionError.ClearDuplicateSessionCookiesError{ + Msg: "The request contains multiple session cookies. We are clearing the cookie from OlderCookieDomain. Session will be refreshed in the next refresh call.", + } + } + + return nil +} + +func hasMultipleCookiesForTokenType(req *http.Request, tokenType sessmodels.TokenType) bool { + // Count of cookies with the specified token type + count := 0 + + // Loop through each cookie in the request + for _, cookie := range req.Cookies() { + // Check if the cookie's name matches the token type + cookieName, _ := getCookieNameFromTokenType(tokenType) + if cookie.Name == cookieName { + count++ + } + } + + // If count is greater than 1, then there are multiple cookies with the given token type + return count > 1 +} diff --git a/recipe/session/errors/errors.go b/recipe/session/errors/errors.go index 402e7f2f..784947a2 100644 --- a/recipe/session/errors/errors.go +++ b/recipe/session/errors/errors.go @@ -18,10 +18,11 @@ package errors import "github.com/supertokens/supertokens-golang/recipe/session/claims" const ( - UnauthorizedErrorStr = "UNAUTHORISED" - TryRefreshTokenErrorStr = "TRY_REFRESH_TOKEN" - TokenTheftDetectedErrorStr = "TOKEN_THEFT_DETECTED" - InvalidClaimsErrorStr = "INVALID_CLAIMS" + UnauthorizedErrorStr = "UNAUTHORISED" + TryRefreshTokenErrorStr = "TRY_REFRESH_TOKEN" + TokenTheftDetectedErrorStr = "TOKEN_THEFT_DETECTED" + InvalidClaimsErrorStr = "INVALID_CLAIMS" + ClearDuplicateSessionCookiesErrorStr = "CLEAR_DUPLICATE_SESSION_COOKIES" ) // TryRefreshTokenError used for when the refresh API needs to be called @@ -66,3 +67,11 @@ type InvalidClaimError struct { func (err InvalidClaimError) Error() string { return err.Msg } + +type ClearDuplicateSessionCookiesError struct { + Msg string +} + +func (err ClearDuplicateSessionCookiesError) Error() string { + return err.Msg +} diff --git a/recipe/session/recipe.go b/recipe/session/recipe.go index 68dca1a2..03e5b39b 100644 --- a/recipe/session/recipe.go +++ b/recipe/session/recipe.go @@ -204,6 +204,13 @@ func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWrit supertokens.LogDebugMessage("errorHandler: returning INVALID_CLAIMS") errs := err.(errors.InvalidClaimError) return true, r.Config.ErrorHandlers.OnInvalidClaim(errs.InvalidClaims, req, res) + } else if defaultErrors.As(err, &errors.ClearDuplicateSessionCookiesError{}) { + supertokens.LogDebugMessage("errorHandler: returning CLEAR_DUPLICATE_SESSION_COOKIES") + // This error occurs in the `refreshPOST` API when multiple session + // cookies are found in the request and the user has set `olderCookieDomain`. + // We remove session cookies from the olderCookieDomain. The response must return `200 OK` + // to avoid logging out the user, allowing the session to continue with the valid cookie. + return true, r.Config.ErrorHandlers.OnClearDuplicateSessionCookies(err.Error(), req, res) } else { return r.OpenIdRecipe.RecipeModule.HandleError(err, req, res, userContext) } diff --git a/recipe/session/sessionErrorHandlers_test.go b/recipe/session/sessionErrorHandlers_test.go new file mode 100644 index 00000000..f839980f --- /dev/null +++ b/recipe/session/sessionErrorHandlers_test.go @@ -0,0 +1,174 @@ +/* + * Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. + * + * This software is licensed under the Apache License, Version 2.0 (the + * "License") as published by the Apache Software Foundation. + * + * You may not use this file except in compliance with the License. You may + * obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package session + +import ( + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/supertokens/supertokens-golang/recipe/session/claims" + sessionErrors "github.com/supertokens/supertokens-golang/recipe/session/errors" + "github.com/supertokens/supertokens-golang/recipe/session/sessmodels" + "github.com/supertokens/supertokens-golang/supertokens" + "github.com/supertokens/supertokens-golang/test/unittesting" + + "github.com/stretchr/testify/assert" +) + +func TestSessionErrorHandlerOverides(t *testing.T) { + BeforeEach() + + customAntiCsrfVal := "VIA_TOKEN" + configValue := supertokens.TypeInput{ + Supertokens: &supertokens.ConnectionInfo{ + ConnectionURI: "http://localhost:8080", + }, + AppInfo: supertokens.AppInfo{ + AppName: "SuperTokens", + WebsiteDomain: "supertokens.io", + APIDomain: "api.supertokens.io", + }, + RecipeList: []supertokens.Recipe{ + Init(&sessmodels.TypeInput{ + AntiCsrf: &customAntiCsrfVal, + ErrorHandlers: &sessmodels.ErrorHandlers{ + OnUnauthorised: func(message string, req *http.Request, res http.ResponseWriter) error { + res.WriteHeader(401) + res.Write([]byte("unauthorised from errorHandler")) + return nil + }, + OnTokenTheftDetected: func(sessionHandle, userID string, req *http.Request, res http.ResponseWriter) error { + res.WriteHeader(403) + res.Write([]byte("token theft detected from errorHandler")) + return nil + }, + OnTryRefreshToken: func(message string, req *http.Request, res http.ResponseWriter) error { + res.WriteHeader(401) + res.Write([]byte("try refresh token from errorHandler")) + return nil + }, + OnInvalidClaim: func(validationErrors []claims.ClaimValidationError, req *http.Request, res http.ResponseWriter) error { + res.WriteHeader(403) + res.Write([]byte("invalid claim from errorHandler")) + return nil + }, + OnClearDuplicateSessionCookies: func(message string, req *http.Request, res http.ResponseWriter) error { + res.WriteHeader(200) + res.Write([]byte("clear duplicate session cookies from errorHandler")) + return nil + }, + }, + GetTokenTransferMethod: func(req *http.Request, forCreateNewSession bool, userContext supertokens.UserContext) sessmodels.TokenTransferMethod { + return sessmodels.CookieTransferMethod + }, + }), + }, + } + + unittesting.StartUpST("localhost", "8080") + defer AfterEach() + err := supertokens.Init(configValue) + if err != nil { + t.Error(err.Error()) + } + + mux := http.NewServeMux() + + mux.HandleFunc("/test/unauthorized", func(rw http.ResponseWriter, r *http.Request) { + supertokens.ErrorHandler(sessionErrors.UnauthorizedError{}, r, rw) + }) + + mux.HandleFunc("/test/try-refresh", func(rw http.ResponseWriter, r *http.Request) { + supertokens.ErrorHandler(sessionErrors.TryRefreshTokenError{}, r, rw) + }) + + mux.HandleFunc("/test/token-theft", func(rw http.ResponseWriter, r *http.Request) { + supertokens.ErrorHandler(sessionErrors.TokenTheftDetectedError{}, r, rw) + }) + + mux.HandleFunc("/test/claim-validation", func(rw http.ResponseWriter, r *http.Request) { + supertokens.ErrorHandler(sessionErrors.InvalidClaimError{}, r, rw) + }) + + mux.HandleFunc("/test/clear-duplicate-session", func(rw http.ResponseWriter, r *http.Request) { + supertokens.ErrorHandler(sessionErrors.ClearDuplicateSessionCookiesError{}, r, rw) + }) + + testServer := httptest.NewServer(supertokens.Middleware(mux)) + defer func() { + testServer.Close() + }() + + t.Run("should override session errorHandlers", func(t *testing.T) { + req, err := http.NewRequest(http.MethodGet, testServer.URL+"/test/unauthorized", nil) + assert.NoError(t, err) + + res, err := http.DefaultClient.Do(req) + assert.NoError(t, err) + assert.Equal(t, 401, res.StatusCode) + + content, err := io.ReadAll(res.Body) + assert.NoError(t, err) + assert.Equal(t, `{"message":"unauthorised from errorHandler"}`, string(content)) + + req, err = http.NewRequest(http.MethodGet, testServer.URL+"/test/try-refresh", nil) + assert.NoError(t, err) + + res, err = http.DefaultClient.Do(req) + assert.NoError(t, err) + assert.Equal(t, 401, res.StatusCode) + + content, err = io.ReadAll(res.Body) + assert.NoError(t, err) + assert.Equal(t, `{"message":"try refresh token from errorHandler"}`, string(content)) + + req, err = http.NewRequest(http.MethodGet, testServer.URL+"/test/token-theft", nil) + assert.NoError(t, err) + + res, err = http.DefaultClient.Do(req) + assert.NoError(t, err) + assert.Equal(t, 403, res.StatusCode) + + content, err = io.ReadAll(res.Body) + assert.NoError(t, err) + assert.Equal(t, `{"message":"token theft detected from errorHandler"}`, string(content)) + + req, err = http.NewRequest(http.MethodGet, testServer.URL+"/test/claim-validation", nil) + assert.NoError(t, err) + + res, err = http.DefaultClient.Do(req) + assert.NoError(t, err) + assert.Equal(t, 403, res.StatusCode) + + content, err = io.ReadAll(res.Body) + assert.NoError(t, err) + assert.Equal(t, `{"message":"invalid claim from errorHandler"}`, string(content)) + + req, err = http.NewRequest(http.MethodGet, testServer.URL+"/test/clear-duplicate-session", nil) + assert.NoError(t, err) + + res, err = http.DefaultClient.Do(req) + assert.NoError(t, err) + assert.Equal(t, 200, res.StatusCode) + + content, err = io.ReadAll(res.Body) + assert.NoError(t, err) + assert.Equal(t, `{"message":"clear duplicate session cookies from errorHandler"}`, string(content)) + }) +} diff --git a/recipe/session/sessionRequestFunctions.go b/recipe/session/sessionRequestFunctions.go index be796f27..e8044d56 100644 --- a/recipe/session/sessionRequestFunctions.go +++ b/recipe/session/sessionRequestFunctions.go @@ -182,6 +182,18 @@ func GetSessionFromRequest(req *http.Request, res http.ResponseWriter, config se accessToken = accessTokens[sessmodels.HeaderTransferMethod] } else if (allowedTokenTransferMethod == sessmodels.AnyTransferMethod || allowedTokenTransferMethod == sessmodels.CookieTransferMethod) && (accessTokens[sessmodels.CookieTransferMethod] != nil) { supertokens.LogDebugMessage("getSession: using cookie transfer method") + + // If multiple access tokens exist in the request cookie, throw TRY_REFRESH_TOKEN. + // This prompts the client to call the refresh endpoint, clearing olderCookieDomain cookies (if set). + // ensuring outdated token payload isn't used. + if hasMultipleCookiesForTokenType(req, sessmodels.AccessToken) { + supertokens.LogDebugMessage("getSession: Throwing TRY_REFRESH_TOKEN because multiple access tokens are present in request cookies") + + return nil, errors.TryRefreshTokenError{ + Msg: "Multiple access tokens present in the request cookies.", + } + } + cookieMethod := sessmodels.CookieTransferMethod requestTokenTransferMethod = &cookieMethod accessToken = accessTokens[sessmodels.CookieTransferMethod] @@ -310,6 +322,11 @@ func GetSessionFromRequest(req *http.Request, res http.ResponseWriter, config se func RefreshSessionInRequest(req *http.Request, res http.ResponseWriter, config sessmodels.TypeNormalisedInput, recipeImpl sessmodels.RecipeInterface, userContext supertokens.UserContext) (sessmodels.SessionContainer, error) { supertokens.LogDebugMessage("refreshSession: Started") + err := ClearSessionCookiesFromOlderCookieDomain(req, res, config, userContext) + if err != nil { + return nil, err + } + refreshTokens := map[sessmodels.TokenTransferMethod]*string{} // We check all token transfer methods for available refresh tokens // We do this so that we can later clear all we are not overwriting @@ -344,7 +361,28 @@ func RefreshSessionInRequest(req *http.Request, res http.ResponseWriter, config setCookie(config, res, legacyIdRefreshTokenCookieName, "", 0, "accessTokenPath", req, userContext) } - supertokens.LogDebugMessage("refreshSession: UNAUTHORISED because refresh token in request is undefined") + // We need to clear the access token cookie if + // - the refresh token is not found, and + // - the allowedTransferMethod is 'cookie' or 'any', and + // - an access token cookie exists (otherwise it'd be a no-op) + // See: https://github.com/supertokens/supertokens-node/issues/790 + token, err := GetToken(req, sessmodels.AccessToken, sessmodels.CookieTransferMethod) + if err != nil { + return nil, err + } + if (allowedTokenTransferMethod == sessmodels.AnyTransferMethod || allowedTokenTransferMethod == sessmodels.CookieTransferMethod) && token != nil { + supertokens.LogDebugMessage("refreshSession: cleared all session tokens and returning UNAUTHORISED because refresh token in request is undefined") + + // We're clearing all session tokens instead of just the access token and then throwing an UNAUTHORISED + // error with `ClearTokens: True`. This approach avoids confusion and we don't want to retain session + // tokens on the client in any case if the refresh API is called without a refresh token but with an access token. + True := true + return nil, errors.UnauthorizedError{ + Msg: "Refresh token not found but access token is present. Clearing all tokens.", + ClearTokens: &True, + } + } + False := false return nil, errors.UnauthorizedError{ Msg: "Refresh token not found. Are you sending the refresh token in the request as a cookie?", diff --git a/recipe/session/sessionUtils_test.go b/recipe/session/sessionUtils_test.go new file mode 100644 index 00000000..fdcf9387 --- /dev/null +++ b/recipe/session/sessionUtils_test.go @@ -0,0 +1,107 @@ +/* + * Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved. + * + * This software is licensed under the Apache License, Version 2.0 (the + * "License") as published by the Apache Software Foundation. + * + * You may not use this file except in compliance with the License. You may + * obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package session + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNormaliseSessionScope(t *testing.T) { + t.Run("test with empty string", func(t *testing.T) { + result, err := normaliseSessionScopeOrThrowError("") + assert.NoError(t, err) + assert.Equal(t, "", *result) + }) + + t.Run("test with leading dot", func(t *testing.T) { + result, err := normaliseSessionScopeOrThrowError(".example.com") + assert.NoError(t, err) + assert.Equal(t, ".example.com", *result) + }) + + t.Run("test without leading dot", func(t *testing.T) { + result, err := normaliseSessionScopeOrThrowError("example.com") + assert.NoError(t, err) + assert.Equal(t, "example.com", *result) + }) + + t.Run("test with http prefix", func(t *testing.T) { + result, err := normaliseSessionScopeOrThrowError("http://example.com") + assert.NoError(t, err) + assert.Equal(t, "example.com", *result) + }) + + t.Run("test with https prefix", func(t *testing.T) { + result, err := normaliseSessionScopeOrThrowError("https://example.com") + assert.NoError(t, err) + assert.Equal(t, "example.com", *result) + }) + + t.Run("test with IP address", func(t *testing.T) { + result, err := normaliseSessionScopeOrThrowError("192.168.1.1") + assert.NoError(t, err) + assert.Equal(t, "192.168.1.1", *result) + }) + + t.Run("test with localhost", func(t *testing.T) { + result, err := normaliseSessionScopeOrThrowError("localhost") + assert.NoError(t, err) + assert.Equal(t, "localhost", *result) + }) + + t.Run("test with leading and trailing whitespace", func(t *testing.T) { + result, err := normaliseSessionScopeOrThrowError(" example.com ") + assert.NoError(t, err) + assert.Equal(t, "example.com", *result) + }) + + t.Run("test with subdomain", func(t *testing.T) { + result, err := normaliseSessionScopeOrThrowError("sub.example.com") + assert.NoError(t, err) + assert.Equal(t, "sub.example.com", *result) + + result, err = normaliseSessionScopeOrThrowError("http://sub.example.com") + assert.NoError(t, err) + assert.Equal(t, "sub.example.com", *result) + + result, err = normaliseSessionScopeOrThrowError("https://sub.example.com") + assert.NoError(t, err) + assert.Equal(t, "sub.example.com", *result) + + result, err = normaliseSessionScopeOrThrowError(".sub.example.com") + assert.NoError(t, err) + assert.Equal(t, ".sub.example.com", *result) + + result, err = normaliseSessionScopeOrThrowError("a.sub.example.com") + assert.NoError(t, err) + assert.Equal(t, "a.sub.example.com", *result) + + result, err = normaliseSessionScopeOrThrowError("http://a.sub.example.com") + assert.NoError(t, err) + assert.Equal(t, "a.sub.example.com", *result) + + result, err = normaliseSessionScopeOrThrowError("https://a.sub.example.com") + assert.NoError(t, err) + assert.Equal(t, "a.sub.example.com", *result) + + result, err = normaliseSessionScopeOrThrowError(".a.sub.example.com") + assert.NoError(t, err) + assert.Equal(t, ".a.sub.example.com", *result) + }) +} diff --git a/recipe/session/session_test.go b/recipe/session/session_test.go index 62fa5012..659726a4 100644 --- a/recipe/session/session_test.go +++ b/recipe/session/session_test.go @@ -18,6 +18,7 @@ package session import ( "encoding/json" + "io" "net/http" "net/http/httptest" "strings" @@ -32,49 +33,65 @@ import ( "github.com/supertokens/supertokens-golang/test/unittesting" ) -func TestOutputHeadersAndSetCookieForCreateSessionIsFine(t *testing.T) { - customAntiCsrfVal := "VIA_TOKEN" - configValue := supertokens.TypeInput{ - Supertokens: &supertokens.ConnectionInfo{ - ConnectionURI: "http://localhost:8080", - }, - AppInfo: supertokens.AppInfo{ - AppName: "SuperTokens", - WebsiteDomain: "supertokens.io", - APIDomain: "api.supertokens.io", - }, - RecipeList: []supertokens.Recipe{ - Init(&sessmodels.TypeInput{ - AntiCsrf: &customAntiCsrfVal, - GetTokenTransferMethod: func(req *http.Request, forCreateNewSession bool, userContext supertokens.UserContext) sessmodels.TokenTransferMethod { - return sessmodels.CookieTransferMethod - }, - }), - }, - } +func TestCookieBasedAuth(t *testing.T) { BeforeEach() unittesting.StartUpST("localhost", "8080") defer AfterEach() - err := supertokens.Init(configValue) - if err != nil { - t.Error(err.Error()) + + cfgVal := func(tokenTransferMethod sessmodels.TokenTransferMethod, olderCookieDomain *string) supertokens.TypeInput { + customAntiCsrfVal := "VIA_TOKEN" + return supertokens.TypeInput{ + Supertokens: &supertokens.ConnectionInfo{ + ConnectionURI: "http://localhost:8080", + }, + AppInfo: supertokens.AppInfo{ + AppName: "SuperTokens", + WebsiteDomain: "supertokens.io", + APIDomain: "api.supertokens.io", + }, + RecipeList: []supertokens.Recipe{ + Init(&sessmodels.TypeInput{ + OlderCookieDomain: olderCookieDomain, + AntiCsrf: &customAntiCsrfVal, + GetTokenTransferMethod: func(req *http.Request, forCreateNewSession bool, userContext supertokens.UserContext) sessmodels.TokenTransferMethod { + return tokenTransferMethod + }, + }), + }, + } } - mux := http.NewServeMux() + err := supertokens.Init(cfgVal(sessmodels.CookieTransferMethod, nil)) + assert.NoError(t, err) + mux := http.NewServeMux() mux.HandleFunc("/create", func(rw http.ResponseWriter, r *http.Request) { CreateNewSession(r, rw, "public", "rope", map[string]interface{}{}, map[string]interface{}{}) }) + customValForAntiCsrfCheck := true + customSessionRequiredValue := true + mux.HandleFunc("/verifySession", VerifySession(&sessmodels.VerifySessionOptions{ + SessionRequired: &customSessionRequiredValue, + AntiCsrfCheck: &customValForAntiCsrfCheck, + }, func(rw http.ResponseWriter, r *http.Request) { + GetSession(r, rw, &sessmodels.VerifySessionOptions{ + SessionRequired: &customSessionRequiredValue, + AntiCsrfCheck: &customValForAntiCsrfCheck, + }) + })) + testServer := httptest.NewServer(supertokens.Middleware(mux)) defer func() { testServer.Close() }() + req, err := http.NewRequest(http.MethodGet, testServer.URL+"/create", nil) assert.NoError(t, err) res, err := http.DefaultClient.Do(req) assert.NoError(t, err) cookieData := unittesting.ExtractInfoFromResponse(res) + assert.Equal(t, []string{"front-token, anti-csrf"}, res.Header["Access-Control-Expose-Headers"]) assert.Equal(t, "", cookieData["refreshTokenDomain"]) assert.Equal(t, "", cookieData["accessTokenDomain"]) @@ -83,6 +100,177 @@ func TestOutputHeadersAndSetCookieForCreateSessionIsFine(t *testing.T) { assert.NotNil(t, cookieData["antiCsrf"]) assert.NotNil(t, cookieData["accessTokenExpiry"]) assert.NotNil(t, cookieData["refreshTokenExpiry"]) + + t.Run("verifySession returns 401 if multiple tokens are passed in the request", func(t *testing.T) { + req, err = http.NewRequest(http.MethodGet, testServer.URL+"/verifySession", nil) + assert.NoError(t, err) + req.Header.Add("Cookie", "sAccessToken="+cookieData["sAccessToken"]) + req.Header.Add("Cookie", "sRefreshToken="+cookieData["sRefreshToken"]) + req.Header.Add("Cookie", "sAccessToken="+cookieData["sAccessToken"]) + req.Header.Add("Cookie", "sRefreshToken="+cookieData["sRefreshToken"]) + res, err = http.DefaultClient.Do(req) + assert.NoError(t, err) + assert.Equal(t, http.StatusUnauthorized, res.StatusCode) + + content, err := io.ReadAll(res.Body) + assert.NoError(t, err) + assert.Equal(t, `{"message":"try refresh token"}`, string(content)) + }) + + t.Run("refresh endpoint throws a 500 if multiple tokens are passed and olderCookieDomain is undefined", func(t *testing.T) { + req, err = http.NewRequest(http.MethodPost, testServer.URL+"/auth/session/refresh", nil) + assert.NoError(t, err) + req.Header.Add("Cookie", "sAccessToken=accessToken1") + req.Header.Add("Cookie", "sAccessToken=accessToken2") + req.Header.Add("Cookie", "sRefreshToken=refreshToken1") + req.Header.Add("Cookie", "sRefreshToken=refreshToken2") + res, err = http.DefaultClient.Do(req) + assert.NoError(t, err) + cookieData = unittesting.ExtractInfoFromResponse(res) + + assert.Equal(t, http.StatusInternalServerError, res.StatusCode) + content, err := io.ReadAll(res.Body) + assert.NoError(t, err) + assert.Equal(t, "The request contains multiple session cookies. This may happen if you've changed the 'cookieDomain' value in your configuration. To clear tokens from the previous domain, set 'olderCookieDomain' in your config.\n", string(content)) + }) + + t.Run("all session tokens are cleared if refresh token api is called without the refresh token but with the access token", func(t *testing.T) { + req, err = http.NewRequest(http.MethodPost, testServer.URL+"/auth/session/refresh", nil) + assert.NoError(t, err) + req.Header.Add("Cookie", "sAccessToken="+cookieData["sAccessToken"]) + req.Header.Add("anti-csrf", cookieData["antiCsrf"]) + res, err = http.DefaultClient.Do(req) + assert.NoError(t, err) + cookieData = unittesting.ExtractInfoFromResponse(res) + + assert.Equal(t, http.StatusUnauthorized, res.StatusCode) + assert.Empty(t, cookieData["sAccessToken"]) + assert.Equal(t, cookieData["accessTokenExpiry"], "Thu, 01 Jan 1970 00:00:00 GMT") + assert.Empty(t, cookieData["sRefreshToken"]) + assert.Equal(t, cookieData["refreshTokenExpiry"], "Thu, 01 Jan 1970 00:00:00 GMT") + }) + + resetAll() + olderCookieName := ".example.com" + err = supertokens.Init(cfgVal(sessmodels.CookieTransferMethod, &olderCookieName)) + assert.NoError(t, err) + + mux = http.NewServeMux() + mux.HandleFunc("/create", func(rw http.ResponseWriter, r *http.Request) { + CreateNewSession(r, rw, "public", "rope", map[string]interface{}{}, map[string]interface{}{}) + }) + + testServer = httptest.NewServer(supertokens.Middleware(mux)) + + t.Run("access and refresh token for olderCookieDomain is cleared if multiple tokens are passed to the refresh endpoint", func(t *testing.T) { + req, err = http.NewRequest(http.MethodPost, testServer.URL+"/auth/session/refresh", nil) + assert.NoError(t, err) + req.Header.Add("Cookie", "sAccessToken=accessToken1") + req.Header.Add("Cookie", "sAccessToken=accessToken2") + req.Header.Add("Cookie", "sRefreshToken=refreshToken1") + req.Header.Add("Cookie", "sRefreshToken=refreshToken2") + res, err = http.DefaultClient.Do(req) + assert.NoError(t, err) + cookieData = unittesting.ExtractInfoFromResponse(res) + + assert.Equal(t, http.StatusOK, res.StatusCode) + assert.Equal(t, http.StatusOK, res.StatusCode) + assert.Empty(t, cookieData["sAccessToken"]) + assert.Equal(t, "Thu, 01 Jan 1970 00:00:00 GMT", cookieData["accessTokenExpiry"]) + assert.Equal(t, "example.com", cookieData["accessTokenDomain"]) // TODO: node sdk returns .example.com + assert.Empty(t, cookieData["sRefreshToken"]) + assert.Equal(t, "Thu, 01 Jan 1970 00:00:00 GMT", cookieData["refreshTokenExpiry"]) + assert.Equal(t, "example.com", cookieData["refreshTokenDomain"]) // TODO: node sdk returns .example.com + }) +} + +func TestHeaderBasedAuthAndMultipleTokensInCookies(t *testing.T) { + BeforeEach() + unittesting.StartUpST("localhost", "8080") + defer AfterEach() + + cfgVal := func(tokenTransferMethod sessmodels.TokenTransferMethod, olderCookieDomain *string) supertokens.TypeInput { + customAntiCsrfVal := "VIA_TOKEN" + return supertokens.TypeInput{ + Supertokens: &supertokens.ConnectionInfo{ + ConnectionURI: "http://localhost:8080", + }, + AppInfo: supertokens.AppInfo{ + AppName: "SuperTokens", + WebsiteDomain: "supertokens.io", + APIDomain: "api.supertokens.io", + }, + RecipeList: []supertokens.Recipe{ + Init(&sessmodels.TypeInput{ + OlderCookieDomain: olderCookieDomain, + AntiCsrf: &customAntiCsrfVal, + GetTokenTransferMethod: func(req *http.Request, forCreateNewSession bool, userContext supertokens.UserContext) sessmodels.TokenTransferMethod { + return tokenTransferMethod + }, + }), + }, + } + } + + err := supertokens.Init(cfgVal(sessmodels.HeaderTransferMethod, nil)) + assert.NoError(t, err) + + mux := http.NewServeMux() + mux.HandleFunc("/create", func(rw http.ResponseWriter, r *http.Request) { + CreateNewSession(r, rw, "public", "testuserid", map[string]interface{}{}, map[string]interface{}{}) + }) + + mux.HandleFunc("/verifySession", func(writer http.ResponseWriter, request *http.Request) { + sessionResponse, _ := GetSession(request, writer, nil) + userID := sessionResponse.GetUserID() + writer.WriteHeader(http.StatusOK) + writer.Write([]byte(userID)) + }) + + testServer := httptest.NewServer(supertokens.Middleware(mux)) + defer func() { + testServer.Close() + }() + + req, err := http.NewRequest(http.MethodGet, testServer.URL+"/create", nil) + assert.NoError(t, err) + res, err := http.DefaultClient.Do(req) + assert.NoError(t, err) + cookieData := unittesting.ExtractInfoFromResponse(res) + + t.Run("verifySession returns 200 in header based auth even if multiple tokens are present in the cookie", func(t *testing.T) { + req, err = http.NewRequest(http.MethodGet, testServer.URL+"/verifySession", nil) + assert.NoError(t, err) + req.Header.Add("Authorization", "Bearer "+cookieData["accessTokenFromHeader"]) + req.Header.Add("Cookie", "sAccessToken="+cookieData["sAccessToken"]) + req.Header.Add("Cookie", "sAccessToken="+cookieData["sAccessToken"]) + res, err = http.DefaultClient.Do(req) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, res.StatusCode) + + content, err := io.ReadAll(res.Body) + assert.NoError(t, err) + assert.Equal(t, `testuserid`, string(content)) + }) + + t.Run("refresh endpoint refreshes the token in header based auth even if multiple tokens are present in the cookie", func(t *testing.T) { + req, err = http.NewRequest(http.MethodPost, testServer.URL+"/auth/session/refresh", nil) + assert.NoError(t, err) + req.Header.Add("Authorization", "Bearer "+cookieData["refreshTokenFromHeader"]) + req.Header.Add("Cookie", "sAccessToken="+cookieData["sAccessToken"]) + req.Header.Add("Cookie", "sRefreshToken="+cookieData["sRefreshToken"]) + req.Header.Add("Cookie", "sAccessToken="+cookieData["sAccessToken"]) + req.Header.Add("Cookie", "sRefreshToken="+cookieData["sRefreshToken"]) + + res, err = http.DefaultClient.Do(req) + assert.NoError(t, err) + cookieData := unittesting.ExtractInfoFromResponse(res) + assert.NoError(t, err) + + assert.Equal(t, http.StatusOK, res.StatusCode) + assert.NotEmpty(t, cookieData["accessTokenFromHeader"]) + assert.NotEmpty(t, cookieData["refreshTokenFromHeader"]) + }) } func TestTokenTheftDetection(t *testing.T) { diff --git a/recipe/session/sessmodels/models.go b/recipe/session/sessmodels/models.go index 75e8b630..f088649b 100644 --- a/recipe/session/sessmodels/models.go +++ b/recipe/session/sessmodels/models.go @@ -104,6 +104,7 @@ type TypeInput struct { SessionExpiredStatusCode *int InvalidClaimStatusCode *int CookieDomain *string + OlderCookieDomain *string AntiCsrf *string Override *OverrideStruct ErrorHandlers *ErrorHandlers @@ -119,14 +120,17 @@ type OverrideStruct struct { } type ErrorHandlers struct { - OnUnauthorised func(message string, req *http.Request, res http.ResponseWriter) error - OnTokenTheftDetected func(sessionHandle string, userID string, req *http.Request, res http.ResponseWriter) error - OnInvalidClaim func(validationErrors []claims.ClaimValidationError, req *http.Request, res http.ResponseWriter) error + OnUnauthorised func(message string, req *http.Request, res http.ResponseWriter) error + OnTryRefreshToken func(message string, req *http.Request, res http.ResponseWriter) error + OnTokenTheftDetected func(sessionHandle string, userID string, req *http.Request, res http.ResponseWriter) error + OnInvalidClaim func(validationErrors []claims.ClaimValidationError, req *http.Request, res http.ResponseWriter) error + OnClearDuplicateSessionCookies func(message string, req *http.Request, res http.ResponseWriter) error } type TypeNormalisedInput struct { RefreshTokenPath supertokens.NormalisedURLPath CookieDomain *string + OlderCookieDomain *string GetCookieSameSite func(request *http.Request, userContext supertokens.UserContext) (string, error) CookieSecure bool SessionExpiredStatusCode int @@ -169,10 +173,11 @@ type APIOptions struct { } type NormalisedErrorHandlers struct { - OnUnauthorised func(message string, req *http.Request, res http.ResponseWriter) error - OnTryRefreshToken func(message string, req *http.Request, res http.ResponseWriter) error - OnTokenTheftDetected func(sessionHandle string, userID string, req *http.Request, res http.ResponseWriter) error - OnInvalidClaim func(validationErrors []claims.ClaimValidationError, req *http.Request, res http.ResponseWriter) error + OnUnauthorised func(message string, req *http.Request, res http.ResponseWriter) error + OnTryRefreshToken func(message string, req *http.Request, res http.ResponseWriter) error + OnTokenTheftDetected func(sessionHandle string, userID string, req *http.Request, res http.ResponseWriter) error + OnInvalidClaim func(validationErrors []claims.ClaimValidationError, req *http.Request, res http.ResponseWriter) error + OnClearDuplicateSessionCookies func(message string, req *http.Request, res http.ResponseWriter) error } type SessionTokens struct { diff --git a/recipe/session/utils.go b/recipe/session/utils.go index 01a302e1..0cfca6db 100644 --- a/recipe/session/utils.go +++ b/recipe/session/utils.go @@ -46,8 +46,9 @@ func GetRequiredClaimValidators( func ValidateAndNormaliseUserInput(appInfo supertokens.NormalisedAppinfo, config *sessmodels.TypeInput) (sessmodels.TypeNormalisedInput, error) { var ( - cookieDomain *string = nil - err error + cookieDomain *string = nil + olderCookieDomain *string = nil + err error ) if config != nil && config.CookieDomain != nil { @@ -57,6 +58,13 @@ func ValidateAndNormaliseUserInput(appInfo supertokens.NormalisedAppinfo, config } } + if config != nil && config.OlderCookieDomain != nil { + olderCookieDomain, err = normaliseSessionScopeOrThrowError(*config.OlderCookieDomain) + if err != nil { + return sessmodels.TypeNormalisedInput{}, err + } + } + if config != nil && config.CookieSameSite != nil { // we have this block just to check if the user input is correct _, err = normaliseSameSiteOrThrowError(*config.CookieSameSite) @@ -172,6 +180,9 @@ func ValidateAndNormaliseUserInput(appInfo supertokens.NormalisedAppinfo, config } return sendInvalidClaimResponse(*recipeInstance, validationErrors, req, res) }, + OnClearDuplicateSessionCookies: func(message string, req *http.Request, res http.ResponseWriter) error { + return supertokens.Send200Response(res, message) + }, } if config != nil && config.ErrorHandlers != nil { @@ -184,6 +195,12 @@ func ValidateAndNormaliseUserInput(appInfo supertokens.NormalisedAppinfo, config if config.ErrorHandlers.OnInvalidClaim != nil { errorHandlers.OnInvalidClaim = config.ErrorHandlers.OnInvalidClaim } + if config.ErrorHandlers.OnTryRefreshToken != nil { + errorHandlers.OnTryRefreshToken = config.ErrorHandlers.OnTryRefreshToken + } + if config.ErrorHandlers.OnClearDuplicateSessionCookies != nil { + errorHandlers.OnClearDuplicateSessionCookies = config.ErrorHandlers.OnClearDuplicateSessionCookies + } } refreshAPIPath, err := supertokens.NewNormalisedURLPath(RefreshAPIPath) @@ -208,6 +225,7 @@ func ValidateAndNormaliseUserInput(appInfo supertokens.NormalisedAppinfo, config typeNormalisedInput := sessmodels.TypeNormalisedInput{ RefreshTokenPath: appInfo.APIBasePath.AppendPath(refreshAPIPath), CookieDomain: cookieDomain, + OlderCookieDomain: olderCookieDomain, GetCookieSameSite: cookieSameSite, CookieSecure: cookieSecure, SessionExpiredStatusCode: sessionExpiredStatusCode, @@ -259,35 +277,45 @@ func GetURLScheme(URL string) (string, error) { } func normaliseSessionScopeOrThrowError(sessionScope string) (*string, error) { - sessionScope = strings.TrimSpace(sessionScope) - sessionScope = strings.ToLower(sessionScope) + helper := func(scope string) (string, error) { + scope = strings.TrimSpace(scope) + scope = strings.ToLower(scope) + + scope = strings.TrimPrefix(scope, ".") + + if !strings.HasPrefix(scope, "http://") && !strings.HasPrefix(scope, "https://") { + scope = "http://" + scope + } - sessionScope = strings.TrimPrefix(sessionScope, ".") + parsedURL, err := url.Parse(scope) + if err != nil { + return "", errors.New("please provide a valid sessionScope") + } - if !strings.HasPrefix(sessionScope, "http://") && !strings.HasPrefix(sessionScope, "https://") { - sessionScope = "http://" + sessionScope + hostname := parsedURL.Hostname() + + return hostname, nil } - urlObj, err := url.Parse(sessionScope) + noDotNormalised, err := helper(sessionScope) if err != nil { - return nil, errors.New("Please provide a valid sessionScope") + return nil, err } - sessionScope = urlObj.Hostname() - sessionScope = strings.TrimPrefix(sessionScope, ".") - - noDotNormalised := sessionScope - isAnIP, err := supertokens.IsAnIPAddress(sessionScope) if err != nil { return nil, err } - if sessionScope == "localhost" || isAnIP { - noDotNormalised = sessionScope + + if noDotNormalised == "localhost" || isAnIP { + return &noDotNormalised, nil } + if strings.HasPrefix(sessionScope, ".") { - noDotNormalised = "." + sessionScope + noDotNormalised = "." + noDotNormalised + return &noDotNormalised, nil } + return &noDotNormalised, nil } diff --git a/supertokens/constants.go b/supertokens/constants.go index e3db5743..a62dfa7f 100644 --- a/supertokens/constants.go +++ b/supertokens/constants.go @@ -21,7 +21,7 @@ const ( ) // VERSION current version of the lib -const VERSION = "0.18.0" +const VERSION = "0.19.0" var ( cdiSupported = []string{"3.0"} diff --git a/test/unittesting/testingutils.go b/test/unittesting/testingutils.go index 33f87547..db74e06c 100644 --- a/test/unittesting/testingutils.go +++ b/test/unittesting/testingutils.go @@ -252,9 +252,12 @@ func ExtractInfoFromResponse(res *http.Response) map[string]string { } for _, property := range strings.Split(cookie, ";") { + if strings.HasPrefix(property, " Domain=") { + refreshTokenDomain = strings.TrimPrefix(property, " Domain=") + } + if strings.Index(property, "HttpOnly") == 1 { refreshTokenHttpOnly = "true" - break } } } else if strings.Split(strings.Split(cookie, ";")[0], "=")[0] == "sAccessToken" { @@ -269,9 +272,12 @@ func ExtractInfoFromResponse(res *http.Response) map[string]string { if strings.Split(strings.Split(cookie, ";")[1], "=")[0] == " Path" { } for _, property := range strings.Split(cookie, ";") { + if strings.HasPrefix(property, " Domain=") { + accessTokenDomain = strings.TrimPrefix(property, " Domain=") + } + if strings.Index(property, "HttpOnly") == 1 { accessTokenHttpOnly = "true" - break } } }