diff --git a/client.go b/client.go index af7c2bc..23fc601 100644 --- a/client.go +++ b/client.go @@ -18,6 +18,9 @@ const engineInfoSQL = ` SELECT url, status, attached_to FROM information_schema.engines WHERE engine_name='%s' ` +const accountError = `account '%s' does not exist in this organization or is not authorized. +Please verify the account name and make sure your service account has the +correct RBAC permissions and is linked to a user` func MakeClient(settings *fireboltSettings, apiEndpoint string) (*ClientImpl, error) { client := &ClientImpl{ @@ -91,7 +94,10 @@ func (c *ClientImpl) getSystemEngineURL(ctx context.Context, accountName string) url := fmt.Sprintf(c.ApiEndpoint+EngineUrlByAccountName, accountName) - response, err := c.request(ctx, "GET", url, make(map[string]string), "") + response, err, err_code := c.request(ctx, "GET", url, make(map[string]string), "") + if err_code == 404 { + return "", fmt.Errorf(accountError, accountName) + } if err != nil { return "", ConstructNestedError("error during system engine url http request", err) } @@ -114,7 +120,10 @@ func (c *ClientImpl) getAccountID(ctx context.Context, accountName string) (stri url := fmt.Sprintf(c.ApiEndpoint+AccountIdByAccountName, accountName) - response, err := c.request(ctx, "GET", url, make(map[string]string), "") + response, err, err_code := c.request(ctx, "GET", url, make(map[string]string), "") + if err_code == 404 { + return "", fmt.Errorf(accountError, accountName) + } if err != nil { return "", ConstructNestedError("error during account id resolution http request", err) } diff --git a/client_base.go b/client_base.go index dfcc20a..b5151cc 100644 --- a/client_base.go +++ b/client_base.go @@ -40,7 +40,7 @@ func (c *BaseClient) Query(ctx context.Context, engineUrl, databaseName, query s return nil, err } - response, err := c.request(ctx, "POST", engineUrl, params, query) + response, err, _ := c.request(ctx, "POST", engineUrl, params, query) if err != nil { return nil, ConstructNestedError("error during query request", err) } @@ -61,16 +61,16 @@ func (c *BaseClient) Query(ctx context.Context, engineUrl, databaseName, query s // request fetches an access token from the cache or re-authenticate when the access token is not available in the cache // and sends a request using that token -func (c *BaseClient) request(ctx context.Context, method string, url string, params map[string]string, bodyStr string) ([]byte, error) { +func (c *BaseClient) request(ctx context.Context, method string, url string, params map[string]string, bodyStr string) ([]byte, error, int) { var err error if c.accessTokenGetter == nil { - return nil, errors.New("accessTokenGetter is not set") + return nil, errors.New("accessTokenGetter is not set"), 0 } accessToken, err := c.accessTokenGetter() if err != nil { - return nil, ConstructNestedError("error while getting access token", err) + return nil, ConstructNestedError("error while getting access token", err), 0 } var response []byte var responseCode int @@ -81,12 +81,12 @@ func (c *BaseClient) request(ctx context.Context, method string, url string, par // Refreshing the access token as it is expired accessToken, err = c.accessTokenGetter() if err != nil { - return nil, ConstructNestedError("error while refreshing access token", err) + return nil, ConstructNestedError("error while refreshing access token", err), 0 } // Trying to send the same request again now that the access token has been refreshed - response, err, _ = request(ctx, accessToken, method, url, c.UserAgent, params, bodyStr, ContentTypeJSON) + response, err, responseCode = request(ctx, accessToken, method, url, c.UserAgent, params, bodyStr, ContentTypeJSON) } - return response, err + return response, err, responseCode } // makeCanonicalUrl checks whether url starts with https:// and if not prepends it diff --git a/client_test.go b/client_test.go index 63e55b8..92554a8 100644 --- a/client_test.go +++ b/client_test.go @@ -2,10 +2,12 @@ package fireboltgosdk import ( "context" + "fmt" "net/http" "net/http/httptest" "os" "strconv" + "strings" "testing" "time" ) @@ -37,7 +39,7 @@ func TestCacheAccessToken(t *testing.T) { client.accessTokenGetter = client.getAccessToken var err error for i := 0; i < 3; i++ { - _, err = client.request(context.TODO(), "GET", server.URL, nil, "") + _, err, _ = client.request(context.TODO(), "GET", server.URL, nil, "") if err != nil { t.Errorf("Did not expect an error %s", err) } @@ -81,7 +83,7 @@ func TestRefreshTokenOn401(t *testing.T) { BaseClient: BaseClient{ClientID: "client_id", ClientSecret: "client_secret", ApiEndpoint: server.URL, UserAgent: "userAgent"}, } client.accessTokenGetter = client.getAccessToken - _, _ = client.request(context.TODO(), "GET", server.URL, nil, "") + _, _, _ = client.request(context.TODO(), "GET", server.URL, nil, "") if getCachedAccessToken("client_id", server.URL) != "aMysteriousToken" { t.Errorf("Did not fetch missing token") @@ -118,10 +120,10 @@ func TestFetchTokenWhenExpired(t *testing.T) { BaseClient: BaseClient{ClientID: "client_id", ClientSecret: "client_secret", ApiEndpoint: server.URL, UserAgent: "userAgent"}, } client.accessTokenGetter = client.getAccessToken - _, _ = client.request(context.TODO(), "GET", server.URL, nil, "") + _, _, _ = client.request(context.TODO(), "GET", server.URL, nil, "") // Waiting for the token to get expired time.Sleep(2 * time.Millisecond) - _, _ = client.request(context.TODO(), "GET", server.URL, nil, "") + _, _, _ = client.request(context.TODO(), "GET", server.URL, nil, "") token, _ := getAccessTokenUsernamePassword("client_id", "", server.URL, "") @@ -175,3 +177,53 @@ func getAuthResponse(expiry int) []byte { }` return []byte(response) } + +func setupTestServerAndClient(t *testing.T, testAccountName string) (*httptest.Server, *ClientImpl) { + // Create a mock server that returns a 404 status code + server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + if req.URL.Path == fmt.Sprintf(EngineUrlByAccountName, testAccountName) || req.URL.Path == fmt.Sprintf(AccountIdByAccountName, testAccountName) { + rw.WriteHeader(http.StatusNotFound) + } else { + _, _ = rw.Write(getAuthResponse(10000)) + } + })) + + prepareEnvVariablesForTest(t, server) + client := &ClientImpl{ + BaseClient: BaseClient{ClientID: "client_id", ClientSecret: "client_secret", ApiEndpoint: server.URL}, + } + client.accessTokenGetter = client.getAccessToken + client.parameterGetter = client.getQueryParams + + return server, client +} + +func TestGetSystemEngineURLReturnsErrorOn404(t *testing.T) { + testAccountName := "testAccount" + server, client := setupTestServerAndClient(t, testAccountName) + defer server.Close() + + // Call the getSystemEngineURL method and check if it returns an error + _, err := client.getSystemEngineURL(context.Background(), testAccountName) + if err == nil { + t.Errorf("Expected an error, got nil") + } + if !strings.HasPrefix(err.Error(), fmt.Sprintf("account '%s' does not exist", testAccountName)) { + t.Errorf("Expected error to start with \"account '%s' does not exist\", got \"%s\"", testAccountName, err.Error()) + } +} + +func TestGetAccountIdReturnsErrorOn404(t *testing.T) { + testAccountName := "testAccount" + server, client := setupTestServerAndClient(t, testAccountName) + defer server.Close() + + // Call the getAccountID method and check if it returns an error + _, err := client.getAccountID(context.Background(), testAccountName) + if err == nil { + t.Errorf("Expected an error, got nil") + } + if !strings.HasPrefix(err.Error(), fmt.Sprintf("account '%s' does not exist", testAccountName)) { + t.Errorf("Expected error to start with \"account '%s' does not exist\", got \"%s\"", testAccountName, err.Error()) + } +} diff --git a/client_v0.go b/client_v0.go index 0888264..4288172 100644 --- a/client_v0.go +++ b/client_v0.go @@ -42,7 +42,7 @@ func (c *ClientImplV0) getAccountIDByName(ctx context.Context, accountName strin params := map[string]string{"account_name": accountName} - response, err := c.request(ctx, "GET", c.ApiEndpoint+AccountIdByNameURL, params, "") + response, err, _ := c.request(ctx, "GET", c.ApiEndpoint+AccountIdByNameURL, params, "") if err != nil { return "", ConstructNestedError("error during getting account id by name request", err) } @@ -64,7 +64,7 @@ func (c *ClientImplV0) getDefaultAccountID(ctx context.Context) (string, error) Account AccountResponse `json:"account"` } - response, err := c.request(ctx, "GET", fmt.Sprintf(c.ApiEndpoint+DefaultAccountURL), make(map[string]string), "") + response, err, _ := c.request(ctx, "GET", fmt.Sprintf(c.ApiEndpoint+DefaultAccountURL), make(map[string]string), "") if err != nil { return "", ConstructNestedError("error during getting default account id request", err) } @@ -105,7 +105,7 @@ func (c *ClientImplV0) getEngineIdByName(ctx context.Context, engineName string, } params := map[string]string{"engine_name": engineName} - response, err := c.request(ctx, "GET", fmt.Sprintf(c.ApiEndpoint+EngineIdByNameURL, accountId), params, "") + response, err, _ := c.request(ctx, "GET", fmt.Sprintf(c.ApiEndpoint+EngineIdByNameURL, accountId), params, "") if err != nil { return "", ConstructNestedError("error during getting engine id by name request", err) } @@ -128,7 +128,7 @@ func (c *ClientImplV0) getEngineUrlById(ctx context.Context, engineId string, ac Engine EngineResponse `json:"engine"` } - response, err := c.request(ctx, "GET", fmt.Sprintf(c.ApiEndpoint+EngineByIdURL, accountId, engineId), make(map[string]string), "") + response, err, _ := c.request(ctx, "GET", fmt.Sprintf(c.ApiEndpoint+EngineByIdURL, accountId, engineId), make(map[string]string), "") if err != nil { return "", ConstructNestedError("error during getting engine url by id request", err) @@ -167,7 +167,7 @@ func (c *ClientImplV0) getEngineUrlByDatabase(ctx context.Context, databaseName } params := map[string]string{"database_name": databaseName} - response, err := c.request(ctx, "GET", fmt.Sprintf(c.ApiEndpoint+EngineUrlByDatabaseNameURL, accountId), params, "") + response, err, _ := c.request(ctx, "GET", fmt.Sprintf(c.ApiEndpoint+EngineUrlByDatabaseNameURL, accountId), params, "") if err != nil { return "", ConstructNestedError("error during getting engine url by database request", err) } diff --git a/client_v0_test.go b/client_v0_test.go index 0ddb6c2..0212630 100644 --- a/client_v0_test.go +++ b/client_v0_test.go @@ -31,7 +31,7 @@ func TestCacheAccessTokenV0(t *testing.T) { client.accessTokenGetter = client.getAccessToken var err error for i := 0; i < 3; i++ { - _, err = client.request(context.TODO(), "GET", server.URL, nil, "") + _, err, _ = client.request(context.TODO(), "GET", server.URL, nil, "") if err != nil { t.Errorf("Did not expect an error %s", err) } @@ -75,7 +75,7 @@ func TestRefreshTokenOn401V0(t *testing.T) { BaseClient{ClientID: "ClientID@firebolt.io", ClientSecret: "password", ApiEndpoint: server.URL, UserAgent: "userAgent"}, } client.accessTokenGetter = client.getAccessToken - _, _ = client.request(context.TODO(), "GET", server.URL, nil, "") + _, _, _ = client.request(context.TODO(), "GET", server.URL, nil, "") if getCachedAccessToken("ClientID@firebolt.io", server.URL) != "aMysteriousToken" { t.Errorf("Did not fetch missing token") @@ -112,10 +112,10 @@ func TestFetchTokenWhenExpiredV0(t *testing.T) { BaseClient{ClientID: "ClientID@firebolt.io", ClientSecret: "password", ApiEndpoint: server.URL, UserAgent: "userAgent"}, } client.accessTokenGetter = client.getAccessToken - _, _ = client.request(context.TODO(), "GET", server.URL, nil, "") + _, _, _ = client.request(context.TODO(), "GET", server.URL, nil, "") // Waiting for the token to get expired time.Sleep(2 * time.Millisecond) - _, _ = client.request(context.TODO(), "GET", server.URL, nil, "") + _, _, _ = client.request(context.TODO(), "GET", server.URL, nil, "") token, _ := getAccessTokenUsernamePassword("ClientID@firebolt.io", "", server.URL, "") diff --git a/driver_integration_test.go b/driver_integration_test.go index c83aa1e..5e68881 100644 --- a/driver_integration_test.go +++ b/driver_integration_test.go @@ -29,6 +29,7 @@ var ( engineNameMock string engineUrlMock string accountNameMock string + serviceAccountNoUserName string clientMock *ClientImpl clientMockWithAccount *ClientImpl ) @@ -71,6 +72,7 @@ func init() { clientMockWithAccount = clientWithAccount.(*ClientImpl) clientMockWithAccount.ConnectedToSystemEngine = true engineUrlMock = getEngineURL() + serviceAccountNoUserName = databaseMock + "_sa_no_user" } func getEngineURL() string { @@ -265,3 +267,94 @@ func containsEngine(rows *sql.Rows, engineToFind string) (bool, error) { } return false, nil } + +func TestIncorrectAccount(t *testing.T) { + _, err := Authenticate(&fireboltSettings{ + clientID: clientIdMock, + clientSecret: clientSecretMock, + accountName: "incorrect_account", + engineName: engineNameMock, + database: databaseMock, + newVersion: true, + }, GetHostNameURL()) + if err == nil { + t.Errorf("Authentication didn't return an error, although it should") + } + if !strings.HasPrefix(err.Error(), "error during getting account id: account 'incorrect_account' does not exist") { + t.Errorf("Authentication didn't return an error with correct message, got: %s", err.Error()) + } +} + +// function that creates a service account and returns its id and secret +func createServiceAccountNoUser(t *testing.T, serviceAccountName string) (string, string) { + serviceAccountDescription := "test_service_account_description" + + db, err := sql.Open("firebolt", dsnSystemEngineMock) + if err != nil { + t.Errorf("failed unexpectedly with %v", err) + } + // create service account + createServiceAccountQuery := fmt.Sprintf("CREATE SERVICE ACCOUNT \"%s\" WITH DESCRIPTION = \"%s\"", serviceAccountName, serviceAccountDescription) + _, err = db.Query(createServiceAccountQuery) + if err != nil { + t.Errorf("The query %s returned an error: %v", createServiceAccountQuery, err) + } + // generate credentials for service account + generateServiceAccountKeyQuery := fmt.Sprintf("CALL fb_GENERATESERVICEACCOUNTKEY('%s')", serviceAccountName) + // get service account id and secret from the result + rows, err := db.Query(generateServiceAccountKeyQuery) + var serviceAccountNameReturned, serviceAccountID, serviceAccountSecret string + for rows.Next() { + if err := rows.Scan(&serviceAccountNameReturned, &serviceAccountID, &serviceAccountSecret); err != nil { + t.Errorf("Failed to retrieve service account id and secret: %v", err) + } + } + // Currently this is bugged so retrieve id via a query if not returned otherwise. FIR-28719 + if serviceAccountID == "" { + getServiceAccountIDQuery := fmt.Sprintf("SELECT service_account_id FROM information_schema.service_accounts WHERE service_account_name = '%s'", serviceAccountName) + rows, err := db.Query(getServiceAccountIDQuery) + if err != nil { + t.Errorf("Failed to retrieve service account id: %v", err) + } + for rows.Next() { + if err := rows.Scan(&serviceAccountID); err != nil { + t.Errorf("Failed to retrieve service account id: %v", err) + } + } + } + return serviceAccountID, serviceAccountSecret +} + +func deleteServiceAccount(t *testing.T, serviceAccountName string) { + db, err := sql.Open("firebolt", dsnSystemEngineMock) + if err != nil { + t.Errorf("failed unexpectedly with %v", err) + } + // delete service account + deleteServiceAccountQuery := fmt.Sprintf("DROP SERVICE ACCOUNT \"%s\"", serviceAccountName) + _, err = db.Query(deleteServiceAccountQuery) + if err != nil { + t.Errorf("The query %s returned an error: %v", deleteServiceAccountQuery, err) + } +} + +// test authentication with service account without a user fails +func TestServiceAccountAuthentication(t *testing.T) { + serviceAccountID, serviceAccountSecret := createServiceAccountNoUser(t, serviceAccountNoUserName) + defer deleteServiceAccount(t, serviceAccountNoUserName) // Delete service account after the test + + _, err := Authenticate(&fireboltSettings{ + clientID: serviceAccountID, + clientSecret: serviceAccountSecret, + accountName: accountNameMock, + engineName: engineNameMock, + database: databaseMock, + newVersion: true, + }, GetHostNameURL()) + if err == nil { + t.Errorf("Authentication didn't return an error, although it should") + } + if !strings.HasPrefix(err.Error(), fmt.Sprintf("error during getting account id: account '%s' does not exist", accountNameMock)) { + t.Errorf("Authentication didn't return an error with correct message, got: %s", err.Error()) + } +}