Skip to content

Commit

Permalink
fix: Better error handling when account is not correct (#79)
Browse files Browse the repository at this point in the history
  • Loading branch information
ptiurin authored Dec 20, 2023
1 parent 375449d commit c3f9aa0
Show file tree
Hide file tree
Showing 6 changed files with 176 additions and 22 deletions.
13 changes: 11 additions & 2 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ const engineInfoSQL = `
SELECT url, status, attached_to FROM information_schema.engines
WHERE engine_name='%s'
`
const accountError = `account '%s' does not exist in this organization or is not authorized.
Please verify the account name and make sure your service account has the
correct RBAC permissions and is linked to a user`

func MakeClient(settings *fireboltSettings, apiEndpoint string) (*ClientImpl, error) {
client := &ClientImpl{
Expand Down Expand Up @@ -91,7 +94,10 @@ func (c *ClientImpl) getSystemEngineURL(ctx context.Context, accountName string)

url := fmt.Sprintf(c.ApiEndpoint+EngineUrlByAccountName, accountName)

response, err := c.request(ctx, "GET", url, make(map[string]string), "")
response, err, err_code := c.request(ctx, "GET", url, make(map[string]string), "")
if err_code == 404 {
return "", fmt.Errorf(accountError, accountName)
}
if err != nil {
return "", ConstructNestedError("error during system engine url http request", err)
}
Expand All @@ -114,7 +120,10 @@ func (c *ClientImpl) getAccountID(ctx context.Context, accountName string) (stri

url := fmt.Sprintf(c.ApiEndpoint+AccountIdByAccountName, accountName)

response, err := c.request(ctx, "GET", url, make(map[string]string), "")
response, err, err_code := c.request(ctx, "GET", url, make(map[string]string), "")
if err_code == 404 {
return "", fmt.Errorf(accountError, accountName)
}
if err != nil {
return "", ConstructNestedError("error during account id resolution http request", err)
}
Expand Down
14 changes: 7 additions & 7 deletions client_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func (c *BaseClient) Query(ctx context.Context, engineUrl, databaseName, query s
return nil, err
}

response, err := c.request(ctx, "POST", engineUrl, params, query)
response, err, _ := c.request(ctx, "POST", engineUrl, params, query)
if err != nil {
return nil, ConstructNestedError("error during query request", err)
}
Expand All @@ -61,16 +61,16 @@ func (c *BaseClient) Query(ctx context.Context, engineUrl, databaseName, query s

// request fetches an access token from the cache or re-authenticate when the access token is not available in the cache
// and sends a request using that token
func (c *BaseClient) request(ctx context.Context, method string, url string, params map[string]string, bodyStr string) ([]byte, error) {
func (c *BaseClient) request(ctx context.Context, method string, url string, params map[string]string, bodyStr string) ([]byte, error, int) {
var err error

if c.accessTokenGetter == nil {
return nil, errors.New("accessTokenGetter is not set")
return nil, errors.New("accessTokenGetter is not set"), 0
}

accessToken, err := c.accessTokenGetter()
if err != nil {
return nil, ConstructNestedError("error while getting access token", err)
return nil, ConstructNestedError("error while getting access token", err), 0
}
var response []byte
var responseCode int
Expand All @@ -81,12 +81,12 @@ func (c *BaseClient) request(ctx context.Context, method string, url string, par
// Refreshing the access token as it is expired
accessToken, err = c.accessTokenGetter()
if err != nil {
return nil, ConstructNestedError("error while refreshing access token", err)
return nil, ConstructNestedError("error while refreshing access token", err), 0
}
// Trying to send the same request again now that the access token has been refreshed
response, err, _ = request(ctx, accessToken, method, url, c.UserAgent, params, bodyStr, ContentTypeJSON)
response, err, responseCode = request(ctx, accessToken, method, url, c.UserAgent, params, bodyStr, ContentTypeJSON)
}
return response, err
return response, err, responseCode
}

// makeCanonicalUrl checks whether url starts with https:// and if not prepends it
Expand Down
60 changes: 56 additions & 4 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ package fireboltgosdk

import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"os"
"strconv"
"strings"
"testing"
"time"
)
Expand Down Expand Up @@ -37,7 +39,7 @@ func TestCacheAccessToken(t *testing.T) {
client.accessTokenGetter = client.getAccessToken
var err error
for i := 0; i < 3; i++ {
_, err = client.request(context.TODO(), "GET", server.URL, nil, "")
_, err, _ = client.request(context.TODO(), "GET", server.URL, nil, "")
if err != nil {
t.Errorf("Did not expect an error %s", err)
}
Expand Down Expand Up @@ -81,7 +83,7 @@ func TestRefreshTokenOn401(t *testing.T) {
BaseClient: BaseClient{ClientID: "client_id", ClientSecret: "client_secret", ApiEndpoint: server.URL, UserAgent: "userAgent"},
}
client.accessTokenGetter = client.getAccessToken
_, _ = client.request(context.TODO(), "GET", server.URL, nil, "")
_, _, _ = client.request(context.TODO(), "GET", server.URL, nil, "")

if getCachedAccessToken("client_id", server.URL) != "aMysteriousToken" {
t.Errorf("Did not fetch missing token")
Expand Down Expand Up @@ -118,10 +120,10 @@ func TestFetchTokenWhenExpired(t *testing.T) {
BaseClient: BaseClient{ClientID: "client_id", ClientSecret: "client_secret", ApiEndpoint: server.URL, UserAgent: "userAgent"},
}
client.accessTokenGetter = client.getAccessToken
_, _ = client.request(context.TODO(), "GET", server.URL, nil, "")
_, _, _ = client.request(context.TODO(), "GET", server.URL, nil, "")
// Waiting for the token to get expired
time.Sleep(2 * time.Millisecond)
_, _ = client.request(context.TODO(), "GET", server.URL, nil, "")
_, _, _ = client.request(context.TODO(), "GET", server.URL, nil, "")

token, _ := getAccessTokenUsernamePassword("client_id", "", server.URL, "")

Expand Down Expand Up @@ -175,3 +177,53 @@ func getAuthResponse(expiry int) []byte {
}`
return []byte(response)
}

func setupTestServerAndClient(t *testing.T, testAccountName string) (*httptest.Server, *ClientImpl) {
// Create a mock server that returns a 404 status code
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if req.URL.Path == fmt.Sprintf(EngineUrlByAccountName, testAccountName) || req.URL.Path == fmt.Sprintf(AccountIdByAccountName, testAccountName) {
rw.WriteHeader(http.StatusNotFound)
} else {
_, _ = rw.Write(getAuthResponse(10000))
}
}))

prepareEnvVariablesForTest(t, server)
client := &ClientImpl{
BaseClient: BaseClient{ClientID: "client_id", ClientSecret: "client_secret", ApiEndpoint: server.URL},
}
client.accessTokenGetter = client.getAccessToken
client.parameterGetter = client.getQueryParams

return server, client
}

func TestGetSystemEngineURLReturnsErrorOn404(t *testing.T) {
testAccountName := "testAccount"
server, client := setupTestServerAndClient(t, testAccountName)
defer server.Close()

// Call the getSystemEngineURL method and check if it returns an error
_, err := client.getSystemEngineURL(context.Background(), testAccountName)
if err == nil {
t.Errorf("Expected an error, got nil")
}
if !strings.HasPrefix(err.Error(), fmt.Sprintf("account '%s' does not exist", testAccountName)) {
t.Errorf("Expected error to start with \"account '%s' does not exist\", got \"%s\"", testAccountName, err.Error())
}
}

func TestGetAccountIdReturnsErrorOn404(t *testing.T) {
testAccountName := "testAccount"
server, client := setupTestServerAndClient(t, testAccountName)
defer server.Close()

// Call the getAccountID method and check if it returns an error
_, err := client.getAccountID(context.Background(), testAccountName)
if err == nil {
t.Errorf("Expected an error, got nil")
}
if !strings.HasPrefix(err.Error(), fmt.Sprintf("account '%s' does not exist", testAccountName)) {
t.Errorf("Expected error to start with \"account '%s' does not exist\", got \"%s\"", testAccountName, err.Error())
}
}
10 changes: 5 additions & 5 deletions client_v0.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func (c *ClientImplV0) getAccountIDByName(ctx context.Context, accountName strin

params := map[string]string{"account_name": accountName}

response, err := c.request(ctx, "GET", c.ApiEndpoint+AccountIdByNameURL, params, "")
response, err, _ := c.request(ctx, "GET", c.ApiEndpoint+AccountIdByNameURL, params, "")
if err != nil {
return "", ConstructNestedError("error during getting account id by name request", err)
}
Expand All @@ -64,7 +64,7 @@ func (c *ClientImplV0) getDefaultAccountID(ctx context.Context) (string, error)
Account AccountResponse `json:"account"`
}

response, err := c.request(ctx, "GET", fmt.Sprintf(c.ApiEndpoint+DefaultAccountURL), make(map[string]string), "")
response, err, _ := c.request(ctx, "GET", fmt.Sprintf(c.ApiEndpoint+DefaultAccountURL), make(map[string]string), "")
if err != nil {
return "", ConstructNestedError("error during getting default account id request", err)
}
Expand Down Expand Up @@ -105,7 +105,7 @@ func (c *ClientImplV0) getEngineIdByName(ctx context.Context, engineName string,
}

params := map[string]string{"engine_name": engineName}
response, err := c.request(ctx, "GET", fmt.Sprintf(c.ApiEndpoint+EngineIdByNameURL, accountId), params, "")
response, err, _ := c.request(ctx, "GET", fmt.Sprintf(c.ApiEndpoint+EngineIdByNameURL, accountId), params, "")
if err != nil {
return "", ConstructNestedError("error during getting engine id by name request", err)
}
Expand All @@ -128,7 +128,7 @@ func (c *ClientImplV0) getEngineUrlById(ctx context.Context, engineId string, ac
Engine EngineResponse `json:"engine"`
}

response, err := c.request(ctx, "GET", fmt.Sprintf(c.ApiEndpoint+EngineByIdURL, accountId, engineId), make(map[string]string), "")
response, err, _ := c.request(ctx, "GET", fmt.Sprintf(c.ApiEndpoint+EngineByIdURL, accountId, engineId), make(map[string]string), "")

if err != nil {
return "", ConstructNestedError("error during getting engine url by id request", err)
Expand Down Expand Up @@ -167,7 +167,7 @@ func (c *ClientImplV0) getEngineUrlByDatabase(ctx context.Context, databaseName
}

params := map[string]string{"database_name": databaseName}
response, err := c.request(ctx, "GET", fmt.Sprintf(c.ApiEndpoint+EngineUrlByDatabaseNameURL, accountId), params, "")
response, err, _ := c.request(ctx, "GET", fmt.Sprintf(c.ApiEndpoint+EngineUrlByDatabaseNameURL, accountId), params, "")
if err != nil {
return "", ConstructNestedError("error during getting engine url by database request", err)
}
Expand Down
8 changes: 4 additions & 4 deletions client_v0_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func TestCacheAccessTokenV0(t *testing.T) {
client.accessTokenGetter = client.getAccessToken
var err error
for i := 0; i < 3; i++ {
_, err = client.request(context.TODO(), "GET", server.URL, nil, "")
_, err, _ = client.request(context.TODO(), "GET", server.URL, nil, "")
if err != nil {
t.Errorf("Did not expect an error %s", err)
}
Expand Down Expand Up @@ -75,7 +75,7 @@ func TestRefreshTokenOn401V0(t *testing.T) {
BaseClient{ClientID: "[email protected]", ClientSecret: "password", ApiEndpoint: server.URL, UserAgent: "userAgent"},
}
client.accessTokenGetter = client.getAccessToken
_, _ = client.request(context.TODO(), "GET", server.URL, nil, "")
_, _, _ = client.request(context.TODO(), "GET", server.URL, nil, "")

if getCachedAccessToken("[email protected]", server.URL) != "aMysteriousToken" {
t.Errorf("Did not fetch missing token")
Expand Down Expand Up @@ -112,10 +112,10 @@ func TestFetchTokenWhenExpiredV0(t *testing.T) {
BaseClient{ClientID: "[email protected]", ClientSecret: "password", ApiEndpoint: server.URL, UserAgent: "userAgent"},
}
client.accessTokenGetter = client.getAccessToken
_, _ = client.request(context.TODO(), "GET", server.URL, nil, "")
_, _, _ = client.request(context.TODO(), "GET", server.URL, nil, "")
// Waiting for the token to get expired
time.Sleep(2 * time.Millisecond)
_, _ = client.request(context.TODO(), "GET", server.URL, nil, "")
_, _, _ = client.request(context.TODO(), "GET", server.URL, nil, "")

token, _ := getAccessTokenUsernamePassword("[email protected]", "", server.URL, "")

Expand Down
93 changes: 93 additions & 0 deletions driver_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ var (
engineNameMock string
engineUrlMock string
accountNameMock string
serviceAccountNoUserName string
clientMock *ClientImpl
clientMockWithAccount *ClientImpl
)
Expand Down Expand Up @@ -71,6 +72,7 @@ func init() {
clientMockWithAccount = clientWithAccount.(*ClientImpl)
clientMockWithAccount.ConnectedToSystemEngine = true
engineUrlMock = getEngineURL()
serviceAccountNoUserName = databaseMock + "_sa_no_user"
}

func getEngineURL() string {
Expand Down Expand Up @@ -265,3 +267,94 @@ func containsEngine(rows *sql.Rows, engineToFind string) (bool, error) {
}
return false, nil
}

func TestIncorrectAccount(t *testing.T) {
_, err := Authenticate(&fireboltSettings{
clientID: clientIdMock,
clientSecret: clientSecretMock,
accountName: "incorrect_account",
engineName: engineNameMock,
database: databaseMock,
newVersion: true,
}, GetHostNameURL())
if err == nil {
t.Errorf("Authentication didn't return an error, although it should")
}
if !strings.HasPrefix(err.Error(), "error during getting account id: account 'incorrect_account' does not exist") {
t.Errorf("Authentication didn't return an error with correct message, got: %s", err.Error())
}
}

// function that creates a service account and returns its id and secret
func createServiceAccountNoUser(t *testing.T, serviceAccountName string) (string, string) {
serviceAccountDescription := "test_service_account_description"

db, err := sql.Open("firebolt", dsnSystemEngineMock)
if err != nil {
t.Errorf("failed unexpectedly with %v", err)
}
// create service account
createServiceAccountQuery := fmt.Sprintf("CREATE SERVICE ACCOUNT \"%s\" WITH DESCRIPTION = \"%s\"", serviceAccountName, serviceAccountDescription)
_, err = db.Query(createServiceAccountQuery)
if err != nil {
t.Errorf("The query %s returned an error: %v", createServiceAccountQuery, err)
}
// generate credentials for service account
generateServiceAccountKeyQuery := fmt.Sprintf("CALL fb_GENERATESERVICEACCOUNTKEY('%s')", serviceAccountName)
// get service account id and secret from the result
rows, err := db.Query(generateServiceAccountKeyQuery)
var serviceAccountNameReturned, serviceAccountID, serviceAccountSecret string
for rows.Next() {
if err := rows.Scan(&serviceAccountNameReturned, &serviceAccountID, &serviceAccountSecret); err != nil {
t.Errorf("Failed to retrieve service account id and secret: %v", err)
}
}
// Currently this is bugged so retrieve id via a query if not returned otherwise. FIR-28719
if serviceAccountID == "" {
getServiceAccountIDQuery := fmt.Sprintf("SELECT service_account_id FROM information_schema.service_accounts WHERE service_account_name = '%s'", serviceAccountName)
rows, err := db.Query(getServiceAccountIDQuery)
if err != nil {
t.Errorf("Failed to retrieve service account id: %v", err)
}
for rows.Next() {
if err := rows.Scan(&serviceAccountID); err != nil {
t.Errorf("Failed to retrieve service account id: %v", err)
}
}
}
return serviceAccountID, serviceAccountSecret
}

func deleteServiceAccount(t *testing.T, serviceAccountName string) {
db, err := sql.Open("firebolt", dsnSystemEngineMock)
if err != nil {
t.Errorf("failed unexpectedly with %v", err)
}
// delete service account
deleteServiceAccountQuery := fmt.Sprintf("DROP SERVICE ACCOUNT \"%s\"", serviceAccountName)
_, err = db.Query(deleteServiceAccountQuery)
if err != nil {
t.Errorf("The query %s returned an error: %v", deleteServiceAccountQuery, err)
}
}

// test authentication with service account without a user fails
func TestServiceAccountAuthentication(t *testing.T) {
serviceAccountID, serviceAccountSecret := createServiceAccountNoUser(t, serviceAccountNoUserName)
defer deleteServiceAccount(t, serviceAccountNoUserName) // Delete service account after the test

_, err := Authenticate(&fireboltSettings{
clientID: serviceAccountID,
clientSecret: serviceAccountSecret,
accountName: accountNameMock,
engineName: engineNameMock,
database: databaseMock,
newVersion: true,
}, GetHostNameURL())
if err == nil {
t.Errorf("Authentication didn't return an error, although it should")
}
if !strings.HasPrefix(err.Error(), fmt.Sprintf("error during getting account id: account '%s' does not exist", accountNameMock)) {
t.Errorf("Authentication didn't return an error with correct message, got: %s", err.Error())
}
}

0 comments on commit c3f9aa0

Please sign in to comment.