-
Notifications
You must be signed in to change notification settings - Fork 2
/
client_test.go
351 lines (305 loc) · 11 KB
/
client_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
package fireboltgosdk
import (
"context"
"fmt"
"log"
"net/http"
"net/http/httptest"
"os"
"strconv"
"strings"
"testing"
"time"
)
// originalEndpoint keeps the value that the FIREBOLT_ENDPOINT environment
// variable had when this package was initialised.
var originalEndpoint string

// init snapshots FIREBOLT_ENDPOINT into originalEndpoint during package
// initialisation, i.e. before any test function runs.
func init() {
	originalEndpoint = os.Getenv("FIREBOLT_ENDPOINT")
}
// raiseIfError records a test failure on t when err is non-nil and does
// nothing otherwise.
//
// The failure is reported through t.Errorf, so the calling test keeps running
// after the error has been recorded. t.Helper is called first so that the
// failure is attributed to the line of the calling test instead of to this
// helper.
func raiseIfError(t *testing.T, err error) {
	t.Helper()
	if err != nil {
		t.Errorf("Encountered error %s", err)
	}
}
// TestCacheAccessToken verifies that the access token obtained while
// authenticating is cached and reused for the requests that follow.
//
// The mock server hands out a token on the service-account login path and
// answers 200 on every other path, counting both kinds of calls. Three client
// requests plus one direct token lookup must result in exactly one token fetch
// and four server calls in total.
func TestCacheAccessToken(t *testing.T) {
	tokenFetches, serverHits := 0, 0
	handler := func(w http.ResponseWriter, r *http.Request) {
		serverHits++
		if r.URL.Path != ServiceAccountLoginURLSuffix {
			w.WriteHeader(http.StatusOK)
			return
		}
		tokenFetches++
		_, _ = w.Write(getAuthResponse(10000))
	}
	server := httptest.NewServer(http.HandlerFunc(handler))
	defer server.Close()
	prepareEnvVariablesForTest(t, server)

	client := &ClientImpl{
		BaseClient: BaseClient{
			ClientID:     "client_id",
			ClientSecret: "client_secret",
			ApiEndpoint:  server.URL,
			UserAgent:    "userAgent",
		},
	}
	client.accessTokenGetter = client.getAccessToken

	// Issue the same request three times; only the first one should trigger a token fetch.
	for attempt := 1; attempt <= 3; attempt++ {
		raiseIfError(t, client.request(context.TODO(), "GET", server.URL, nil, "").err)
	}

	token, _ := getAccessTokenServiceAccount("client_id", "", server.URL, "")
	if token != "aMysteriousToken" {
		t.Errorf("Did not fetch missing token")
	}
	if cached := getCachedAccessToken("client_id", server.URL); cached != "aMysteriousToken" {
		t.Errorf("Did not fetch missing token")
	}

	if tokenFetches != 1 {
		t.Errorf("Did not fetch token only once. Total: %d", tokenFetches)
	}
	if serverHits != 4 {
		t.Errorf("Expected to call the server 4 times (1x to fetch token and 3x to send another request). Total: %d", serverHits)
	}
}
// TestRefreshTokenOn401 tests that a token is refreshed when the server returns a 401.
//
// The mock server answers the service-account login path with a valid auth
// response and every other path with 401 Unauthorized, counting token fetches
// and total calls. A single client request is expected to cause two token
// fetches and two request attempts (the original one and its retry).
func TestRefreshTokenOn401(t *testing.T) {
	var fetchTokenCount = 0
	var totalCount = 0
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path == ServiceAccountLoginURLSuffix {
			fetchTokenCount++
			_, _ = w.Write(getAuthResponse(10000))
		} else {
			w.WriteHeader(http.StatusUnauthorized)
		}
		totalCount++
	}))
	defer server.Close()
	prepareEnvVariablesForTest(t, server)

	var client = &ClientImpl{
		BaseClient: BaseClient{ClientID: "client_id", ClientSecret: "client_secret", ApiEndpoint: server.URL, UserAgent: "userAgent"},
	}
	client.accessTokenGetter = client.getAccessToken

	// The result is deliberately ignored: the server always answers 401, and
	// this test only inspects the token cache and the call counters.
	_ = client.request(context.TODO(), "GET", server.URL, nil, "")

	if getCachedAccessToken("client_id", server.URL) != "aMysteriousToken" {
		t.Errorf("Did not fetch missing token")
	}

	if fetchTokenCount != 2 {
		// The token should be fetched twice as it is removed from the cache due to the 401 and then fetched again
		t.Errorf("Did not fetch token twice. Total: %d", fetchTokenCount)
	}

	if totalCount != 4 {
		// The token is fetched twice and the request is retried
		t.Errorf("Expected to call the server 4 times (2x to fetch tokens and 2x to send the request that returns a 401). Total: %d", totalCount)
	}
}
// TestFetchTokenWhenExpired tests that a new token is fetched upon expiry.
//
// The mock server hands out tokens with expires_in set to 1 on the
// service-account login path and answers 200 on every other path, counting
// token fetches and total calls. Two client requests separated by a short
// sleep are expected to cause two token fetches, i.e. four server calls.
func TestFetchTokenWhenExpired(t *testing.T) {
	var fetchTokenCount = 0
	var totalCount = 0
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path == ServiceAccountLoginURLSuffix {
			fetchTokenCount++
			_, _ = w.Write(getAuthResponse(1))
		} else {
			w.WriteHeader(http.StatusOK)
		}
		totalCount++
	}))
	defer server.Close()
	prepareEnvVariablesForTest(t, server)

	var client = &ClientImpl{
		BaseClient: BaseClient{ClientID: "client_id", ClientSecret: "client_secret", ApiEndpoint: server.URL, UserAgent: "userAgent"},
	}
	client.accessTokenGetter = client.getAccessToken

	_ = client.request(context.TODO(), "GET", server.URL, nil, "")

	// Waiting for the token to get expired
	// NOTE(review): expires_in is 1 while the sleep is 2ms; this presumably
	// relies on how the token cache interprets expires_in - confirm.
	time.Sleep(2 * time.Millisecond)

	_ = client.request(context.TODO(), "GET", server.URL, nil, "")

	// The total of 4 asserted below only holds if this lookup does not reach
	// the server, since the handler counts every incoming call.
	token, _ := getAccessTokenUsernamePassword("client_id", "", server.URL, "")
	if token != "aMysteriousToken" {
		t.Errorf("Did not fetch missing token")
	}

	if getCachedAccessToken("client_id", server.URL) != "aMysteriousToken" {
		t.Errorf("Did not fetch missing token")
	}

	if fetchTokenCount != 2 {
		// The token should be fetched twice as it is automatically removed from the cache because it is expired
		t.Errorf("Did not fetch token twice. Total: %d", fetchTokenCount)
	}

	if totalCount != 4 {
		// Two token fetches plus the two requests sent above
		t.Errorf("Expected to call the server 4 times (2x to fetch tokens and 2x to send the request). Total: %d", totalCount)
	}
}
// TestUserAgent verifies that the UserAgent configured on the client is sent
// as the User-Agent header.
//
// The mock server answers 200 to everything and remembers the User-Agent
// header of the most recent request it received; the result of the query
// itself is ignored.
func TestUserAgent(t *testing.T) {
	const expectedUserAgent = "userAgent"

	receivedUserAgent := ""
	recorder := func(w http.ResponseWriter, r *http.Request) {
		receivedUserAgent = r.UserAgent()
		w.WriteHeader(http.StatusOK)
	}
	server := httptest.NewServer(http.HandlerFunc(recorder))
	defer server.Close()
	prepareEnvVariablesForTest(t, server)

	client := &ClientImpl{
		BaseClient: BaseClient{
			ClientID:     "client_id",
			ClientSecret: "client_secret",
			ApiEndpoint:  server.URL,
			UserAgent:    expectedUserAgent,
		},
	}
	client.accessTokenGetter = client.getAccessToken
	client.parameterGetter = client.getQueryParams

	_, _ = client.Query(context.TODO(), server.URL, "SELECT 1", map[string]string{}, connectionControl{})

	if receivedUserAgent == expectedUserAgent {
		return
	}
	t.Errorf("Did not set User-Agent value correctly on a query request")
}
// clientFactory builds a ClientImpl that targets apiEndpoint with fixed test
// credentials, wires its access-token and query-parameter getters to the
// client's own methods, and (re)initialises the caches.
//
// A failure of initialiseCaches is only logged; the client is returned either
// way.
func clientFactory(apiEndpoint string) Client {
	impl := &ClientImpl{
		BaseClient: BaseClient{
			ClientID:     "client_id",
			ClientSecret: "client_secret",
			ApiEndpoint:  apiEndpoint,
		},
	}
	impl.accessTokenGetter = impl.getAccessToken
	impl.parameterGetter = impl.getQueryParams

	if cacheErr := initialiseCaches(); cacheErr != nil {
		log.Printf("Error while initializing caches: %s", cacheErr)
	}
	return impl
}
// TestProtocolVersion tests that protocol version is correctly set on request.
//
// The assertions live in the testProtocolVersion helper, which is not defined
// in this section; this wrapper only hands it clientFactory, presumably so
// that the helper can build the client under test.
func TestProtocolVersion(t *testing.T) {
testProtocolVersion(t, clientFactory)
}
// TestUpdateParameters tests that update parameters are correctly set on request.
//
// The assertions live in the testUpdateParameters helper, which is not defined
// in this section; this wrapper only hands it clientFactory, presumably so
// that the helper can build the client under test.
func TestUpdateParameters(t *testing.T) {
testUpdateParameters(t, clientFactory)
}
// getAuthResponse returns the JSON body of a mock authentication response.
//
// The access token is always "aMysteriousToken"; expiry is written verbatim
// (in decimal) as the value of the "expires_in" field.
func getAuthResponse(expiry int) []byte {
	const head = `{
"access_token": "aMysteriousToken",
"refresh_token": "refresh",
"scope": "offline_access",
"expires_in": `
	const tail = `,
"token_type": "Bearer"
}`

	// 20 bytes is enough for the decimal form of any 64-bit integer, sign included.
	body := make([]byte, 0, len(head)+len(tail)+20)
	body = append(body, head...)
	body = strconv.AppendInt(body, int64(expiry), 10)
	body = append(body, tail...)
	return body
}
// setupTestServerAndClient starts a mock server and returns it together with
// a ClientImpl that points at it.
//
// The server answers 404 on the engine-URL path built from testAccountName
// and returns a mock auth response on every other path. The caller is
// responsible for closing the returned server.
func setupTestServerAndClient(t *testing.T, testAccountName string) (*httptest.Server, *ClientImpl) {
	engineURLPath := fmt.Sprintf(EngineUrlByAccountName, testAccountName)

	handler := func(rw http.ResponseWriter, req *http.Request) {
		if req.URL.Path != engineURLPath {
			_, _ = rw.Write(getAuthResponse(10000))
			return
		}
		// The account lookup is the only request that is rejected.
		rw.WriteHeader(http.StatusNotFound)
	}
	server := httptest.NewServer(http.HandlerFunc(handler))
	prepareEnvVariablesForTest(t, server)

	client := &ClientImpl{
		BaseClient: BaseClient{
			ClientID:     "client_id",
			ClientSecret: "client_secret",
			ApiEndpoint:  server.URL,
		},
	}
	client.accessTokenGetter = client.getAccessToken
	client.parameterGetter = client.getQueryParams

	return server, client
}
// TestGetSystemEngineURLReturnsErrorOn404 checks that looking up the system
// engine URL fails with an "account '<name>' does not exist" error when the
// server answers the account lookup with 404.
func TestGetSystemEngineURLReturnsErrorOn404(t *testing.T) {
	testAccountName := "testAccount"
	server, client := setupTestServerAndClient(t, testAccountName)
	defer server.Close()

	// Call the getSystemEngineURLAndParameters method and check if it returns an error
	_, _, err := client.getSystemEngineURLAndParameters(context.Background(), testAccountName, "")
	if err == nil {
		// Stop here: the prefix check below calls err.Error(), which would
		// panic on a nil error.
		t.Fatalf("Expected an error, got nil")
	}

	wantPrefix := fmt.Sprintf("account '%s' does not exist", testAccountName)
	if !strings.HasPrefix(err.Error(), wantPrefix) {
		t.Errorf("Expected error to start with \"account '%s' does not exist\", got \"%s\"", testAccountName, err.Error())
	}
}
// TestGetSystemEngineURLCaching checks that the system engine URL is requested
// from the server only once, both across repeated lookups on one client and
// across a second client created afterwards.
//
// The mock server counts the calls to the engine-URL path of the test account
// and returns a mock auth response on every other path.
func TestGetSystemEngineURLCaching(t *testing.T) {
	const testAccountName = "testAccount"
	engineURLPath := fmt.Sprintf(EngineUrlByAccountName, testAccountName)

	engineURLRequests := 0
	handler := func(rw http.ResponseWriter, req *http.Request) {
		if req.URL.Path != engineURLPath {
			_, _ = rw.Write(getAuthResponse(10000))
			return
		}
		_, _ = rw.Write([]byte(`{"engineUrl": "https://my_url.com"}`))
		engineURLRequests++
	}
	server := httptest.NewServer(http.HandlerFunc(handler))
	defer server.Close()
	prepareEnvVariablesForTest(t, server)

	// lookup resolves the system engine URL through c and records any error on t.
	lookup := func(c *ClientImpl) {
		_, _, err := c.getSystemEngineURLAndParameters(context.Background(), testAccountName, "")
		raiseIfError(t, err)
	}

	firstClient := clientFactory(server.URL).(*ClientImpl)
	lookup(firstClient)
	lookup(firstClient)
	if engineURLRequests != 1 {
		t.Errorf("Expected to call the server only once, got %d", engineURLRequests)
	}

	// A freshly created client must not hit the server again.
	secondClient := clientFactory(server.URL).(*ClientImpl)
	lookup(secondClient)
	// Still only one call, as the cache is shared between clients
	if engineURLRequests != 1 {
		t.Errorf("Expected to call the server only once, got %d", engineURLRequests)
	}
}
// TestUpdateEndpoint checks how a query reacts to the update-endpoint
// response header.
//
// The mock server serves auth responses on the two login paths and, on every
// other path, sets updateEndpointHeader to "new-endpoint/path?query=param".
// After the query, the setEngineURL callback must have received
// "new-endpoint/path" and the updateParameters callback must have stored
// query=param in the parameter map.
func TestUpdateEndpoint(t *testing.T) {
	const advertisedEndpoint = "new-endpoint/path?query=param"

	handler := func(w http.ResponseWriter, r *http.Request) {
		switch path := r.URL.Path; {
		case path == ServiceAccountLoginURLSuffix:
			_, _ = w.Write(getAuthResponse(10000))
		case path == UsernamePasswordURLSuffix:
			_, _ = w.Write(getAuthResponseV0(10000))
		default:
			w.Header().Set(updateEndpointHeader, advertisedEndpoint)
			w.WriteHeader(http.StatusOK)
		}
	}
	server := httptest.NewServer(http.HandlerFunc(handler))
	defer server.Close()
	prepareEnvVariablesForTest(t, server)

	client := clientFactory(server.URL)

	params := map[string]string{"database": "db"}
	engineEndpoint := "old-endpoint"

	// Callbacks through which the query reports parameter and endpoint changes.
	storeParameter := func(key, value string) {
		params[key] = value
	}
	storeEngineURL := func(value string) {
		engineEndpoint = value
	}
	control := connectionControl{
		updateParameters: storeParameter,
		setEngineURL:     storeEngineURL,
	}

	_, err := client.Query(context.TODO(), server.URL, "SELECT 1", params, control)
	raiseIfError(t, err)

	if got := params["query"]; got != "param" {
		t.Errorf("Query parameter was not set correctly. Expected 'param' but was %s", got)
	}
	if want := "new-endpoint/path"; engineEndpoint != want {
		t.Errorf("Engine endpoint was not set correctly. Expected %s but was %s", want, engineEndpoint)
	}
}
// TestResetSession checks that the resetParameters callback is invoked when a
// query response carries resetSessionHeader.
//
// The mock server serves an auth response on the service-account login path
// and, on every other path, sets resetSessionHeader to "true" and answers 200.
func TestResetSession(t *testing.T) {
	handler := func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != ServiceAccountLoginURLSuffix {
			w.Header().Set(resetSessionHeader, "true")
			w.WriteHeader(http.StatusOK)
			return
		}
		_, _ = w.Write(getAuthResponse(10000))
	}
	server := httptest.NewServer(http.HandlerFunc(handler))
	defer server.Close()
	prepareEnvVariablesForTest(t, server)

	client := clientFactory(server.URL)

	sessionWasReset := false
	markReset := func() {
		sessionWasReset = true
	}
	params := map[string]string{"database": "db"}

	_, err := client.Query(context.TODO(), server.URL, "SELECT 1", params, connectionControl{resetParameters: markReset})
	raiseIfError(t, err)

	if sessionWasReset {
		return
	}
	t.Errorf("Reset session was not called")
}
// TestAdditionalHeaders runs the testAdditionalHeaders helper, which is not
// defined in this section, handing it clientFactory, presumably so that the
// helper can build the client under test.
func TestAdditionalHeaders(t *testing.T) {
testAdditionalHeaders(t, clientFactory)
}