Skip to content

Commit

Permalink
refactor: FIR-35471 remove account resolve call (#114)
Browse files Browse the repository at this point in the history
  • Loading branch information
stepansergeevitch authored Sep 4, 2024
1 parent 3bf942b commit adc8337
Show file tree
Hide file tree
Showing 8 changed files with 13 additions and 205 deletions.
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ go get github.com/firebolt-db/firebolt-go-sdk
### Example
Here is an example of establishing a connection and executing a simple select query.
For it to run successfully, you have to specify your credentials, and have a default engine up and running.

```go
package main

Expand Down
50 changes: 0 additions & 50 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ var URLCache cache.Cache
type ClientImpl struct {
ConnectedToSystemEngine bool
AccountName string
AccountVersion int
BaseClient
}

Expand Down Expand Up @@ -56,11 +55,6 @@ func MakeClient(settings *fireboltSettings, apiEndpoint string) (*ClientImpl, er
infolog.Printf("Error during cache initialisation: %v", err)
}

var err error
client.AccountID, client.AccountVersion, err = client.getAccountInfo(context.Background(), settings.accountName)
if err != nil {
return nil, ConstructNestedError("error during getting account id", err)
}
return client, nil
}

Expand Down Expand Up @@ -124,50 +118,6 @@ func (c *ClientImpl) getSystemEngineURLAndParameters(ctx context.Context, accoun
return engineUrl, parameters, nil
}

func (c *ClientImpl) getAccountInfo(ctx context.Context, accountName string) (string, int, error) {

type AccountIdURLResponse struct {
Id string `json:"id"`
Region string `json:"region"`
InfraVersion int `json:"infraVersion"`
}

url := fmt.Sprintf(c.ApiEndpoint+AccountInfoByAccountName, accountName)

if AccountCache != nil {
val := AccountCache.Get(url)
if val != nil {
if accountInfo, ok := val.(AccountIdURLResponse); ok {
infolog.Printf("Resolved account %s to id %s from cache", accountName, accountInfo.Id)
return accountInfo.Id, accountInfo.InfraVersion, nil
}
}
}
infolog.Printf("Getting account ID for '%s'", accountName)

resp := c.request(ctx, "GET", url, make(map[string]string), "")
if resp.statusCode == 404 {
return "", 0, fmt.Errorf(accountError, accountName)
}
if resp.err != nil {
return "", 0, ConstructNestedError("error during account id resolution http request", resp.err)
}

var accountIdURLResponse AccountIdURLResponse
// InfraVersion should default to 1 if not present
accountIdURLResponse.InfraVersion = 1
if err := json.Unmarshal(resp.data, &accountIdURLResponse); err != nil {
return "", 0, ConstructNestedError("error during unmarshalling account id resolution URL response", errors.New(string(resp.data)))
}
if AccountCache != nil {
AccountCache.Put(url, accountIdURLResponse, 0) //nolint:errcheck
}

infolog.Printf("Resolved account %s to id %s", accountName, accountIdURLResponse.Id)

return accountIdURLResponse.Id, accountIdURLResponse.InfraVersion, nil
}

func (c *ClientImpl) getQueryParams(setStatements map[string]string) (map[string]string, error) {
params := map[string]string{"output_format": outputFormat}
for setKey, setValue := range setStatements {
Expand Down
4 changes: 0 additions & 4 deletions client_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ type Client interface {
type BaseClient struct {
ClientID string
ClientSecret string
AccountID string
ApiEndpoint string
UserAgent string
parameterGetter func(map[string]string) (map[string]string, error)
Expand Down Expand Up @@ -134,9 +133,6 @@ func (c *BaseClient) handleUpdateEndpoint(updateEndpointRaw string, control conn
if err != nil {
return corruptUrlError
}
if accId, ok := newParameters["account_id"]; ok && accId[0] != c.AccountID {
return errors.New("Failed to execute USE ENGINE command. Account parameter mismatch. Contact support")
}
// set engine URL as a full URL excluding query parameters
control.setEngineURL(updateEndpoint)
// update client parameters with new parameters
Expand Down
122 changes: 1 addition & 121 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ func getAuthResponse(expiry int) []byte {
func setupTestServerAndClient(t *testing.T, testAccountName string) (*httptest.Server, *ClientImpl) {
// Create a mock server that returns a 404 status code
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if req.URL.Path == fmt.Sprintf(EngineUrlByAccountName, testAccountName) || req.URL.Path == fmt.Sprintf(AccountInfoByAccountName, testAccountName) {
if req.URL.Path == fmt.Sprintf(EngineUrlByAccountName, testAccountName) {
rw.WriteHeader(http.StatusNotFound)
} else {
_, _ = rw.Write(getAuthResponse(10000))
Expand Down Expand Up @@ -277,126 +277,6 @@ func TestGetSystemEngineURLCaching(t *testing.T) {
}
}

func TestGetAccountInfoReturnsErrorOn404(t *testing.T) {
testAccountName := "testAccount"
server, client := setupTestServerAndClient(t, testAccountName)
defer server.Close()

// Call the getAccountID method and check if it returns an error
_, _, err := client.getAccountInfo(context.Background(), testAccountName)
if err == nil {
t.Errorf("Expected an error, got nil")
}
if !strings.HasPrefix(err.Error(), fmt.Sprintf("account '%s' does not exist", testAccountName)) {
t.Errorf("Expected error to start with \"account '%s' does not exist\", got \"%s\"", testAccountName, err.Error())
}
}

func TestGetAccountInfo(t *testing.T) {
testAccountName := "testAccount"

// Create a mock server that returns a 200 status code
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if req.URL.Path == fmt.Sprintf(AccountInfoByAccountName, testAccountName) {
_, _ = rw.Write([]byte(`{"id": "account_id", "infraVersion": 2}`))
} else {
_, _ = rw.Write(getAuthResponse(10000))
}
}))

prepareEnvVariablesForTest(t, server)
client := &ClientImpl{
BaseClient: BaseClient{ClientID: "client_id", ClientSecret: "client_secret", ApiEndpoint: server.URL},
}
client.accessTokenGetter = client.getAccessToken
client.parameterGetter = client.getQueryParams

// Call the getAccountID method and check if it returns the correct account ID and version
accountID, accountVersion, err := client.getAccountInfo(context.Background(), testAccountName)
raiseIfError(t, err)
if accountID != "account_id" {
t.Errorf("Expected account ID to be 'account_id', got %s", accountID)
}
if accountVersion != 2 {
t.Errorf("Expected account version to be 2, got %d", accountVersion)
}
}

func TestGetAccountInfoCached(t *testing.T) {
testAccountName := "testAccount"

urlCalled := 0

server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if req.URL.Path == fmt.Sprintf(AccountInfoByAccountName, testAccountName) {
_, _ = rw.Write([]byte(`{"id": "account_id", "infraVersion": 2}`))
urlCalled++
} else {
_, _ = rw.Write(getAuthResponse(10000))
}
}))

prepareEnvVariablesForTest(t, server)

var client = clientFactory(server.URL).(*ClientImpl)

// Account info should be fetched from the cache so the server should not be called
accountID, accountVersion, err := client.getAccountInfo(context.Background(), testAccountName)
raiseIfError(t, err)
if accountID != "account_id" {
t.Errorf("Expected account ID to be 'account_id', got %s", accountID)
}
if accountVersion != 2 {
t.Errorf("Expected account version to be 2, got %d", accountVersion)
}
url := fmt.Sprintf(server.URL+AccountInfoByAccountName, testAccountName)
if AccountCache.Get(url) == nil {
t.Errorf("Expected account info to be cached")
}
_, _, err = client.getAccountInfo(context.Background(), testAccountName)
raiseIfError(t, err)
if urlCalled != 1 {
t.Errorf("Expected to call the server only once, got %d", urlCalled)
}
client = clientFactory(server.URL).(*ClientImpl)
_, _, err = client.getAccountInfo(context.Background(), testAccountName)
raiseIfError(t, err)
// Still only one call, as the cache is shared between clients
if urlCalled != 1 {
t.Errorf("Expected to call the server only once, got %d", urlCalled)
}
}

func TestGetAccountInfoDefaultVersion(t *testing.T) {
testAccountName := "testAccount"

// Create a mock server that returns a 200 status code
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if req.URL.Path == fmt.Sprintf(AccountInfoByAccountName, testAccountName) {
_, _ = rw.Write([]byte(`{"id": "account_id"}`))
} else {
_, _ = rw.Write(getAuthResponse(10000))
}
}))

prepareEnvVariablesForTest(t, server)
client := &ClientImpl{
BaseClient: BaseClient{ClientID: "client_id", ClientSecret: "client_secret", ApiEndpoint: server.URL},
}
client.accessTokenGetter = client.getAccessToken
client.parameterGetter = client.getQueryParams

// Call the getAccountID method and check if it returns the correct account ID and version
accountID, accountVersion, err := client.getAccountInfo(context.Background(), testAccountName)
raiseIfError(t, err)
if accountID != "account_id" {
t.Errorf("Expected account ID to be 'account_id', got %s", accountID)
}
if accountVersion != 1 {
t.Errorf("Expected account version to be 1, got %d", accountVersion)
}
}

func TestUpdateEndpoint(t *testing.T) {
var newEndpoint = "new-endpoint/path?query=param"
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down
1 change: 1 addition & 0 deletions client_v0.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
)

type ClientImplV0 struct {
AccountID string
BaseClient
}

Expand Down
5 changes: 5 additions & 0 deletions client_v0_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ func TestCacheAccessTokenV0(t *testing.T) {
defer server.Close()
prepareEnvVariablesForTest(t, server)
var client = &ClientImplV0{
"",
BaseClient{ClientID: "[email protected]", ClientSecret: "password", ApiEndpoint: server.URL, UserAgent: "userAgent"},
}
client.accessTokenGetter = client.getAccessToken
Expand Down Expand Up @@ -71,6 +72,7 @@ func TestRefreshTokenOn401V0(t *testing.T) {
defer server.Close()
prepareEnvVariablesForTest(t, server)
var client = &ClientImplV0{
"",
BaseClient{ClientID: "[email protected]", ClientSecret: "password", ApiEndpoint: server.URL, UserAgent: "userAgent"},
}
client.accessTokenGetter = client.getAccessToken
Expand Down Expand Up @@ -108,6 +110,7 @@ func TestFetchTokenWhenExpiredV0(t *testing.T) {
defer server.Close()
prepareEnvVariablesForTest(t, server)
var client = &ClientImplV0{
"",
BaseClient{ClientID: "[email protected]", ClientSecret: "password", ApiEndpoint: server.URL, UserAgent: "userAgent"},
}
client.accessTokenGetter = client.getAccessToken
Expand Down Expand Up @@ -147,6 +150,7 @@ func TestUserAgentV0(t *testing.T) {
defer server.Close()
prepareEnvVariablesForTest(t, server)
var client = &ClientImplV0{
"",
BaseClient{ClientID: "[email protected]", ClientSecret: "password", ApiEndpoint: server.URL, UserAgent: userAgentValue},
}
client.accessTokenGetter = client.getAccessToken
Expand All @@ -160,6 +164,7 @@ func TestUserAgentV0(t *testing.T) {

func clientFactoryV0(apiEndpoint string) Client {
var client = &ClientImplV0{
"",
BaseClient{ClientID: "[email protected]", ClientSecret: "password", ApiEndpoint: apiEndpoint},
}
client.accessTokenGetter = client.getAccessToken
Expand Down
34 changes: 6 additions & 28 deletions driver_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,23 +208,6 @@ func TestDriverSystemEngineDbContext(t *testing.T) {
}
}

func TestIncorrectAccount(t *testing.T) {
_, err := Authenticate(&fireboltSettings{
clientID: clientIdMock,
clientSecret: clientSecretMock,
accountName: "incorrect_account",
engineName: engineNameMock,
database: databaseMock,
newVersion: true,
}, GetHostNameURL())
if err == nil {
t.Errorf("Authentication didn't return an error, although it should")
}
if !strings.HasPrefix(err.Error(), "error during getting account id: account 'incorrect_account' does not exist") {
t.Errorf("Authentication didn't return an error with correct message, got: %s", err.Error())
}
}

// function that creates a service account and returns its id and secret
func createServiceAccountNoUser(t *testing.T, serviceAccountName string) (string, string) {
serviceAccountDescription := "test_service_account_description"
Expand Down Expand Up @@ -283,21 +266,16 @@ func TestServiceAccountAuthentication(t *testing.T) {
serviceAccountID, serviceAccountSecret := createServiceAccountNoUser(t, serviceAccountNoUserName)
defer deleteServiceAccount(t, serviceAccountNoUserName) // Delete service account after the test

// Clear the cache to ensure that the new service account is used
AccountCache.ClearAll()
dsnNoUser := fmt.Sprintf(
"firebolt:///%s?account_name=%s&engine=%s&client_id=%s&client_secret=%s",
databaseMock, accountName, engineNameMock, serviceAccountID, serviceAccountSecret)

_, err := Authenticate(&fireboltSettings{
clientID: serviceAccountID,
clientSecret: serviceAccountSecret,
accountName: accountName,
engineName: engineNameMock,
database: databaseMock,
newVersion: true,
}, GetHostNameURL())
_, err := sql.Open("firebolt", dsnNoUser)
if err == nil {
t.Errorf("Authentication didn't return an error, although it should")
t.FailNow()
}
if !strings.HasPrefix(err.Error(), fmt.Sprintf("error during getting account id: account '%s' does not exist", accountName)) {
if !strings.Contains(err.Error(), fmt.Sprintf("Database '%s' does not exist or not authorized", databaseMock)) {
t.Errorf("Authentication didn't return an error with correct message, got: %s", err.Error())
}
}
Expand Down
1 change: 0 additions & 1 deletion urls.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package fireboltgosdk
const (
ServiceAccountLoginURLSuffix = "/oauth/token"
EngineUrlByAccountName = "/web/v3/account/%s/engineUrl"
AccountInfoByAccountName = "/web/v3/account/%s/resolve"
//API v0
UsernamePasswordURLSuffix = "/auth/v1/login"
DefaultAccountURL = "/iam/v2/account"
Expand Down

0 comments on commit adc8337

Please sign in to comment.