Skip to content

Commit

Permalink
feat: Cache url and account info (#104)
Browse files Browse the repository at this point in the history
  • Loading branch information
ptiurin authored Jun 4, 2024
1 parent dda746b commit 39fc87d
Show file tree
Hide file tree
Showing 2 changed files with 166 additions and 23 deletions.
78 changes: 69 additions & 9 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,14 @@ import (
"errors"
"fmt"
"strings"

"github.com/astaxie/beego/cache"
)

// Static caches on pacakge level
var AccountCache cache.Cache
var URLCache cache.Cache

type ClientImpl struct {
ConnectedToSystemEngine bool
AccountName string
Expand All @@ -24,6 +30,21 @@ const accountError = `account '%s' does not exist in this organization or is not
Please verify the account name and make sure your service account has the
correct RBAC permissions and is linked to a user`

func initialiseCaches() error {
var err error
if AccountCache == nil {
if AccountCache, err = cache.NewCache("memory", `{}`); err != nil {
return fmt.Errorf("could not create account cache: %v", err)
}
}
if URLCache == nil {
if URLCache, err = cache.NewCache("memory", `{}`); err != nil {
return fmt.Errorf("could not create url cache: %v", err)
}
}
return nil
}

func MakeClient(settings *fireboltSettings, apiEndpoint string) (*ClientImpl, error) {
client := &ClientImpl{
BaseClient: BaseClient{
Expand All @@ -37,14 +58,17 @@ func MakeClient(settings *fireboltSettings, apiEndpoint string) (*ClientImpl, er
client.parameterGetter = client.getQueryParams
client.accessTokenGetter = client.getAccessToken

if err := initialiseCaches(); err != nil {
infolog.Printf("Error during cache initialisation: %v", err)
}

var err error
client.AccountID, client.AccountVersion, err = client.getAccountInfo(context.Background(), settings.accountName)
if err != nil {
return nil, ConstructNestedError("error during getting account id", err)
}
return client, nil
}

func (c *ClientImpl) getEngineUrlStatusDBByName(ctx context.Context, engineName string, systemEngineUrl string) (string, string, string, error) {
infolog.Printf("Get info for engine '%s'", engineName)
engineSQL := fmt.Sprintf(engineInfoSQL, engineName)
Expand Down Expand Up @@ -84,6 +108,17 @@ func parseEngineInfoResponse(resp [][]interface{}) (string, string, string, erro
return engineUrl, status, dbName, nil
}

func constructParameters(databaseName string, queryParams map[string][]string) map[string]string {
parameters := make(map[string]string)
if len(databaseName) != 0 {
parameters["database"] = databaseName
}
for key, value := range queryParams {
parameters[key] = value[0]
}
return parameters
}

func (c *ClientImpl) getSystemEngineURLAndParameters(ctx context.Context, accountName string, databaseName string) (string, map[string]string, error) {
infolog.Printf("Get system engine URL for account '%s'", accountName)

Expand All @@ -92,6 +127,21 @@ func (c *ClientImpl) getSystemEngineURLAndParameters(ctx context.Context, accoun
}

url := fmt.Sprintf(c.ApiEndpoint+EngineUrlByAccountName, accountName)
// Check if the URL is in the cache
if URLCache != nil {
val := URLCache.Get(url)
if val != nil {
if systemEngineURLResponse, ok := val.(SystemEngineURLResponse); ok {
infolog.Printf("Resolved account %s to system engine URL %s from cache", accountName, systemEngineURLResponse.EngineUrl)
engineUrl, queryParams, err := splitEngineEndpoint(systemEngineURLResponse.EngineUrl)
if err != nil {
return "", nil, ConstructNestedError("error during splitting system engine URL", err)
}
parameters := constructParameters(databaseName, queryParams)
return engineUrl, parameters, nil
}
}
}

resp := c.request(ctx, "GET", url, make(map[string]string), "")
if resp.statusCode == 404 {
Expand All @@ -105,24 +155,20 @@ func (c *ClientImpl) getSystemEngineURLAndParameters(ctx context.Context, accoun
if err := json.Unmarshal(resp.data, &systemEngineURLResponse); err != nil {
return "", nil, ConstructNestedError("error during unmarshalling system engine URL response", errors.New(string(resp.data)))
}
if URLCache != nil {
URLCache.Put(url, systemEngineURLResponse, 0) //nolint:errcheck
}
engineUrl, queryParams, err := splitEngineEndpoint(systemEngineURLResponse.EngineUrl)
if err != nil {
return "", nil, ConstructNestedError("error during splitting system engine URL", err)
}

parameters := make(map[string]string)
if len(databaseName) != 0 {
parameters["database"] = databaseName
}
for key, value := range queryParams {
parameters[key] = value[0]
}
parameters := constructParameters(databaseName, queryParams)

return engineUrl, parameters, nil
}

func (c *ClientImpl) getAccountInfo(ctx context.Context, accountName string) (string, int, error) {
infolog.Printf("Getting account ID for '%s'", accountName)

type AccountIdURLResponse struct {
Id string `json:"id"`
Expand All @@ -132,6 +178,17 @@ func (c *ClientImpl) getAccountInfo(ctx context.Context, accountName string) (st

url := fmt.Sprintf(c.ApiEndpoint+AccountInfoByAccountName, accountName)

if AccountCache != nil {
val := AccountCache.Get(url)
if val != nil {
if accountInfo, ok := val.(AccountIdURLResponse); ok {
infolog.Printf("Resolved account %s to id %s from cache", accountName, accountInfo.Id)
return accountInfo.Id, accountInfo.InfraVersion, nil
}
}
}
infolog.Printf("Getting account ID for '%s'", accountName)

resp := c.request(ctx, "GET", url, make(map[string]string), "")
if resp.statusCode == 404 {
return "", 0, fmt.Errorf(accountError, accountName)
Expand All @@ -146,6 +203,9 @@ func (c *ClientImpl) getAccountInfo(ctx context.Context, accountName string) (st
if err := json.Unmarshal(resp.data, &accountIdURLResponse); err != nil {
return "", 0, ConstructNestedError("error during unmarshalling account id resolution URL response", errors.New(string(resp.data)))
}
if AccountCache != nil {
AccountCache.Put(url, accountIdURLResponse, 0) //nolint:errcheck
}

infolog.Printf("Resolved account %s to id %s", accountName, accountIdURLResponse.Id)

Expand Down
111 changes: 97 additions & 14 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package fireboltgosdk
import (
"context"
"fmt"
"log"
"net/http"
"net/http/httptest"
"os"
Expand All @@ -18,6 +19,12 @@ func init() {

var originalEndpoint string

func raiseIfError(t *testing.T, err error) {
if err != nil {
t.Errorf("Encountered error %s", err)
}
}

// TestCacheAccessToken tests that a token is cached during authentication and reused for subsequent requests
func TestCacheAccessToken(t *testing.T) {
var fetchTokenCount = 0
Expand All @@ -39,9 +46,7 @@ func TestCacheAccessToken(t *testing.T) {
client.accessTokenGetter = client.getAccessToken
for i := 0; i < 3; i++ {
resp := client.request(context.TODO(), "GET", server.URL, nil, "")
if resp.err != nil {
t.Errorf("Did not expect an error %s", resp.err)
}
raiseIfError(t, resp.err)
}

token, _ := getAccessTokenServiceAccount("client_id", "", server.URL, "")
Expand Down Expand Up @@ -172,6 +177,10 @@ func clientFactory(apiEndpoint string) Client {
}
client.accessTokenGetter = client.getAccessToken
client.parameterGetter = client.getQueryParams
err := initialiseCaches()
if err != nil {
log.Printf("Error while initializing caches: %s", err)
}
return client
}

Expand Down Expand Up @@ -231,6 +240,43 @@ func TestGetSystemEngineURLReturnsErrorOn404(t *testing.T) {
}
}

func TestGetSystemEngineURLCaching(t *testing.T) {
testAccountName := "testAccount"

urlCalled := 0

server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if req.URL.Path == fmt.Sprintf(EngineUrlByAccountName, testAccountName) {
_, _ = rw.Write([]byte(`{"engineUrl": "https://my_url.com"}`))
urlCalled++
} else {
_, _ = rw.Write(getAuthResponse(10000))
}
}))
defer server.Close()
prepareEnvVariablesForTest(t, server)

var client = clientFactory(server.URL).(*ClientImpl)

var err error
_, _, err = client.getSystemEngineURLAndParameters(context.Background(), testAccountName, "")
raiseIfError(t, err)
_, _, err = client.getSystemEngineURLAndParameters(context.Background(), testAccountName, "")
raiseIfError(t, err)
if urlCalled != 1 {
t.Errorf("Expected to call the server only once, got %d", urlCalled)
}
// Create a new client

client = clientFactory(server.URL).(*ClientImpl)
_, _, err = client.getSystemEngineURLAndParameters(context.Background(), testAccountName, "")
raiseIfError(t, err)
// Still only one call, as the cache is shared between clients
if urlCalled != 1 {
t.Errorf("Expected to call the server only once, got %d", urlCalled)
}
}

func TestGetAccountInfoReturnsErrorOn404(t *testing.T) {
testAccountName := "testAccount"
server, client := setupTestServerAndClient(t, testAccountName)
Expand Down Expand Up @@ -267,15 +313,58 @@ func TestGetAccountInfo(t *testing.T) {

// Call the getAccountID method and check if it returns the correct account ID and version
accountID, accountVersion, err := client.getAccountInfo(context.Background(), testAccountName)
if err != nil {
t.Errorf("Expected no error, got %s", err)
raiseIfError(t, err)
if accountID != "account_id" {
t.Errorf("Expected account ID to be 'account_id', got %s", accountID)
}
if accountVersion != 2 {
t.Errorf("Expected account version to be 2, got %d", accountVersion)
}
}

func TestGetAccountInfoCached(t *testing.T) {
testAccountName := "testAccount"

urlCalled := 0

server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if req.URL.Path == fmt.Sprintf(AccountInfoByAccountName, testAccountName) {
_, _ = rw.Write([]byte(`{"id": "account_id", "infraVersion": 2}`))
urlCalled++
} else {
_, _ = rw.Write(getAuthResponse(10000))
}
}))

prepareEnvVariablesForTest(t, server)

var client = clientFactory(server.URL).(*ClientImpl)

// Account info should be fetched from the cache so the server should not be called
accountID, accountVersion, err := client.getAccountInfo(context.Background(), testAccountName)
raiseIfError(t, err)
if accountID != "account_id" {
t.Errorf("Expected account ID to be 'account_id', got %s", accountID)
}
if accountVersion != 2 {
t.Errorf("Expected account version to be 2, got %d", accountVersion)
}
url := fmt.Sprintf(server.URL+AccountInfoByAccountName, testAccountName)
if AccountCache.Get(url) == nil {
t.Errorf("Expected account info to be cached")
}
_, _, err = client.getAccountInfo(context.Background(), testAccountName)
raiseIfError(t, err)
if urlCalled != 1 {
t.Errorf("Expected to call the server only once, got %d", urlCalled)
}
client = clientFactory(server.URL).(*ClientImpl)
_, _, err = client.getAccountInfo(context.Background(), testAccountName)
raiseIfError(t, err)
// Still only one call, as the cache is shared between clients
if urlCalled != 1 {
t.Errorf("Expected to call the server only once, got %d", urlCalled)
}
}

func TestGetAccountInfoDefaultVersion(t *testing.T) {
Expand All @@ -299,9 +388,7 @@ func TestGetAccountInfoDefaultVersion(t *testing.T) {

// Call the getAccountID method and check if it returns the correct account ID and version
accountID, accountVersion, err := client.getAccountInfo(context.Background(), testAccountName)
if err != nil {
t.Errorf("Expected no error, got %s", err)
}
raiseIfError(t, err)
if accountID != "account_id" {
t.Errorf("Expected account ID to be 'account_id', got %s", accountID)
}
Expand Down Expand Up @@ -340,9 +427,7 @@ func TestUpdateEndpoint(t *testing.T) {
engineEndpoint = value
},
})
if err != nil {
t.Errorf("Error during query execution with update parameters header in response %s", err)
}
raiseIfError(t, err)
if params["query"] != "param" {
t.Errorf("Query parameter was not set correctly. Expected 'param' but was %s", params["query"])
}
Expand Down Expand Up @@ -375,9 +460,7 @@ func TestResetSession(t *testing.T) {
resetCalled = true
},
})
if err != nil {
t.Errorf("Error during query execution with reset session header in response %s", err)
}
raiseIfError(t, err)
if !resetCalled {
t.Errorf("Reset session was not called")
}
Expand Down

0 comments on commit 39fc87d

Please sign in to comment.