Skip to content

Commit

Permalink
Add test to make sure response headers work correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
nkshah2 committed Oct 20, 2023
1 parent 178929e commit 336548e
Show file tree
Hide file tree
Showing 2 changed files with 226 additions and 1 deletion.
10 changes: 9 additions & 1 deletion recipe/session/cookieAndHeaders.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,15 @@ func setHeader(res http.ResponseWriter, key, value string, allowDuplicateKey boo
if existingValue == "" {
res.Header().Set(key, value)
} else if allowDuplicateKey {
res.Header().Set(key, existingValue+", "+value)
/**
We only want to append if it does not already exist
For example if the caller is trying to add front token to the access control exposed headers property
we do not want to append if something else had already added it
*/
if !strings.Contains(existingValue, value) {
res.Header().Set(key, existingValue+", "+value)
}
} else {
res.Header().Set(key, value)
}
Expand Down
217 changes: 217 additions & 0 deletions recipe/session/verifySession_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -976,6 +977,151 @@ func TestThatAntiCSRFCheckIsSkippedIfSessionRequiredIsFalseAndNoAccessTokenIsPas
assert.Equal(t, res.StatusCode, 200)
}

func TestThatResponseHeadersAreCorrectWhenUsingCookies(t *testing.T) {
configValue := supertokens.TypeInput{
Supertokens: &supertokens.ConnectionInfo{
ConnectionURI: "http://localhost:8080",
},
AppInfo: supertokens.AppInfo{
AppName: "SuperTokens",
WebsiteDomain: "supertokens.io",
APIDomain: "api.supertokens.io",
},
RecipeList: []supertokens.Recipe{
Init(&sessmodels.TypeInput{}),
},
}
BeforeEach()
unittesting.StartUpST("localhost", "8080")
defer AfterEach()
err := supertokens.Init(configValue)
if err != nil {
t.Error(err.Error())
}

app := getTestApp([]typeTestEndpoint{})
defer app.Close()

sessionCookies := createSessionWithCookies(app, map[string]interface{}{})
print(sessionCookies)

var accessToken string

for _, cookie := range sessionCookies {
if cookie.Name == "sAccessToken" {
accessToken = cookie.Value
}
}

assert.NotNil(t, accessToken)

req, err := http.NewRequest(http.MethodGet, app.URL+"/merge-payload", nil)

if err != nil {
t.Error(err.Error())
}

req.Header.Add("Cookie", "sAccessToken="+accessToken)

res, err := http.DefaultClient.Do(req)

if err != nil {
t.Error(err.Error())
}

cookiesHeaderValues := res.Header.Values("Set-Cookie")
accessTokenCount := 0

for _, cookieValue := range cookiesHeaderValues {
if strings.Contains(cookieValue, "sAccessToken") {
accessTokenCount += 1
}
}

assert.Equal(t, accessTokenCount, 1)

accessAllowHeaderValues := strings.Split(res.Header.Get("Access-Control-Expose-Headers"), ",")
frontTokenCount := 0

for _, value := range accessAllowHeaderValues {
if strings.Contains(value, "front-token") {
frontTokenCount += 1
}
}

assert.Equal(t, frontTokenCount, 1)
/**
Goland does not realise that the test passed because the start and end prints happen in the same line
This extra print adds a break line with the "--- PASS:" line being added in a new line making it clear
to the IDE that the test passed.
Leaving this in to avoid future confusion. Weirdly this happens intermittently for all tests.
*/
fmt.Println("")
}

func TestThatResponseHeadersAreCorrectWhenUsingHeaders(t *testing.T) {
configValue := supertokens.TypeInput{
Supertokens: &supertokens.ConnectionInfo{
ConnectionURI: "http://localhost:8080",
},
AppInfo: supertokens.AppInfo{
AppName: "SuperTokens",
WebsiteDomain: "supertokens.io",
APIDomain: "api.supertokens.io",
},
RecipeList: []supertokens.Recipe{
Init(&sessmodels.TypeInput{}),
},
}
BeforeEach()
unittesting.StartUpST("localhost", "8080")
defer AfterEach()
err := supertokens.Init(configValue)
if err != nil {
t.Error(err.Error())
}

app := getTestApp([]typeTestEndpoint{})
defer app.Close()

headers := createSessionWithHeaders(app, map[string]interface{}{})
accessToken := headers.Get("st-access-token")

assert.NotNil(t, accessToken)

req, err := http.NewRequest(http.MethodGet, app.URL+"/merge-payload", nil)

if err != nil {
t.Error(err.Error())
}

req.Header.Add("Authorization", "Bearer "+accessToken)

res, err := http.DefaultClient.Do(req)

if err != nil {
t.Error(err.Error())
}

accessTokenHeaderValues := res.Header.Values("st-access-token")
accessTokenCount := len(accessTokenHeaderValues)

assert.Equal(t, accessTokenCount, 1)

accessAllowHeaderValues := strings.Split(res.Header.Get("Access-Control-Expose-Headers"), ",")
frontTokenCount := 0

for _, value := range accessAllowHeaderValues {
if strings.Contains(value, "front-token") {
frontTokenCount += 1
}
}

assert.Equal(t, frontTokenCount, 1)
}

type typeTestEndpoint struct {
path string
overrideGlobalClaimValidators func(globalClaimValidators []claims.SessionClaimValidator, sessionContainer sessmodels.SessionContainer, userContext supertokens.UserContext) ([]claims.SessionClaimValidator, error)
Expand All @@ -993,6 +1139,46 @@ func createSession(app *httptest.Server, body map[string]interface{}) []*http.Co
return res.Cookies()
}

func createSessionWithCookies(app *httptest.Server, body map[string]interface{}) []*http.Cookie {
bodyBytes := []byte("{}")
if body != nil {
bodyBytes, _ = json.Marshal(body)
}
req, err := http.NewRequest(http.MethodPost, app.URL+"/create", bytes.NewBuffer(bodyBytes))
if err != nil {
return nil
}

req.Header.Set("st-auth-mode", "cookie")

res, err := http.DefaultClient.Do(req)
if err != nil {
return nil
}

return res.Cookies()
}

func createSessionWithHeaders(app *httptest.Server, body map[string]interface{}) http.Header {
bodyBytes := []byte("{}")
if body != nil {
bodyBytes, _ = json.Marshal(body)
}
req, err := http.NewRequest(http.MethodPost, app.URL+"/create", bytes.NewBuffer(bodyBytes))
if err != nil {
return nil
}

req.Header.Set("st-auth-mode", "header")

res, err := http.DefaultClient.Do(req)
if err != nil {
return nil
}

return res.Header
}

func getTestApp(endpoints []typeTestEndpoint) *httptest.Server {
mux := http.NewServeMux()

Expand Down Expand Up @@ -1035,6 +1221,37 @@ func getTestApp(endpoints []typeTestEndpoint) *httptest.Server {
GetSession(r, rw, &sessmodels.VerifySessionOptions{})
}))

mux.HandleFunc("/merge-payload", VerifySession(&sessmodels.VerifySessionOptions{}, func(rw http.ResponseWriter, r *http.Request) {
session, err := GetSession(r, rw, &sessmodels.VerifySessionOptions{})

if err != nil {
rw.WriteHeader(500)
return
}

session.MergeIntoAccessTokenPayload(map[string]interface{}{
"lastUpdate": "123",
})
session.MergeIntoAccessTokenPayload(map[string]interface{}{
"lastUpdate": "456",
})
session.MergeIntoAccessTokenPayload(map[string]interface{}{
"lastUpdate": "789",
})

resp := map[string]interface{}{
"status": "OK",
}
respBytes, err := json.Marshal(resp)
if err != nil {
return
}
rw.Header().Set("Content-Type", "application/json")
rw.Header().Set("Content-Length", fmt.Sprintf("%d", (len(respBytes))))
rw.WriteHeader(http.StatusOK)
rw.Write(respBytes)
}))

mux.HandleFunc("/default-claims", VerifySession(nil, func(w http.ResponseWriter, r *http.Request) {
sessionContainer := GetSessionFromRequestContext(r.Context())
resp := map[string]interface{}{
Expand Down

0 comments on commit 336548e

Please sign in to comment.