Skip to content

Commit

Permalink
feat: Fir 30154 support use engine in go sdk (#89)
Browse files Browse the repository at this point in the history
  • Loading branch information
stepansergeevitch authored Mar 25, 2024
1 parent ddc0a3f commit da68727
Show file tree
Hide file tree
Showing 19 changed files with 579 additions and 121 deletions.
5 changes: 3 additions & 2 deletions .github/workflows/integration-tests-v2.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,16 @@ jobs:
firebolt-client-id: ${{ secrets.FIREBOLT_CLIENT_ID_STG_NEW_IDN }}
firebolt-client-secret: ${{ secrets.FIREBOLT_CLIENT_SECRET_STG_NEW_IDN}}
api-endpoint: "api.staging.firebolt.io"
account: ${{ vars.FIREBOLT_ACCOUNT }}
account: ${{ vars.FIREBOLT_ACCOUNT_V1 }}
instance-type: "B2"

- name: Run integration tests
env:
DATABASE_NAME: ${{ steps.setup.outputs.database_name }}
ENGINE_NAME: ${{ steps.setup.outputs.engine_name }}
FIREBOLT_ENDPOINT: "api.staging.firebolt.io"
ACCOUNT_NAME: ${{ vars.FIREBOLT_ACCOUNT }}
ACCOUNT_NAME_V1: ${{ vars.FIREBOLT_ACCOUNT_V1 }}
ACCOUNT_NAME_V2: ${{ vars.FIREBOLT_ACCOUNT_V2 }}
CLIENT_ID: ${{ secrets.FIREBOLT_CLIENT_ID_STG_NEW_IDN }}
CLIENT_SECRET: ${{ secrets.FIREBOLT_CLIENT_SECRET_STG_NEW_IDN }}
run: |
Expand Down
90 changes: 66 additions & 24 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
type ClientImpl struct {
ConnectedToSystemEngine bool
SystemEngineURL string
AccountVersion int
BaseClient
}

Expand All @@ -35,7 +36,7 @@ func MakeClient(settings *fireboltSettings, apiEndpoint string) (*ClientImpl, er
client.accessTokenGetter = client.getAccessToken

var err error
client.AccountID, err = client.getAccountID(context.Background(), settings.accountName)
client.AccountID, client.AccountVersion, err = client.getAccountInfo(context.Background(), settings.accountName)
if err != nil {
return nil, ConstructNestedError("error during getting account id", err)
}
Expand All @@ -49,9 +50,7 @@ func MakeClient(settings *fireboltSettings, apiEndpoint string) (*ClientImpl, er
func (c *ClientImpl) getEngineUrlStatusDBByName(ctx context.Context, engineName string, systemEngineUrl string) (string, string, string, error) {
infolog.Printf("Get info for engine '%s'", engineName)
engineSQL := fmt.Sprintf(engineInfoSQL, engineName)
queryRes, err := c.Query(ctx, systemEngineUrl, engineSQL, make(map[string]string), func(key, value string) {
// No need to support set statements for engine info query
})
queryRes, err := c.Query(ctx, systemEngineUrl, engineSQL, make(map[string]string), connectionControl{})
if err != nil {
return "", "", "", ConstructNestedError("error executing engine info sql query", err)
}
Expand Down Expand Up @@ -108,36 +107,44 @@ func (c *ClientImpl) getSystemEngineURL(ctx context.Context, accountName string)
if err := json.Unmarshal(resp.data, &systemEngineURLResponse); err != nil {
return "", ConstructNestedError("error during unmarshalling system engine URL response", errors.New(string(resp.data)))
}
// Ignore any query parameters provided in the URL
engineUrl, _, err := splitEngineEndpoint(systemEngineURLResponse.EngineUrl)
if err != nil {
return "", ConstructNestedError("error during splitting system engine URL", err)
}

return systemEngineURLResponse.EngineUrl, nil
return engineUrl, nil
}

func (c *ClientImpl) getAccountID(ctx context.Context, accountName string) (string, error) {
func (c *ClientImpl) getAccountInfo(ctx context.Context, accountName string) (string, int, error) {
infolog.Printf("Getting account ID for '%s'", accountName)

type AccountIdURLResponse struct {
Id string `json:"id"`
Region string `json:"region"`
Id string `json:"id"`
Region string `json:"region"`
InfraVersion int `json:"infraVersion"`
}

url := fmt.Sprintf(c.ApiEndpoint+AccountIdByAccountName, accountName)
url := fmt.Sprintf(c.ApiEndpoint+AccountInfoByAccountName, accountName)

resp := c.request(ctx, "GET", url, make(map[string]string), "")
if resp.statusCode == 404 {
return "", fmt.Errorf(accountError, accountName)
return "", 0, fmt.Errorf(accountError, accountName)
}
if resp.err != nil {
return "", ConstructNestedError("error during account id resolution http request", resp.err)
return "", 0, ConstructNestedError("error during account id resolution http request", resp.err)
}

var accountIdURLResponse AccountIdURLResponse
// InfraVersion should default to 1 if not present
accountIdURLResponse.InfraVersion = 1
if err := json.Unmarshal(resp.data, &accountIdURLResponse); err != nil {
return "", ConstructNestedError("error during unmarshalling account id resolution URL response", errors.New(string(resp.data)))
return "", 0, ConstructNestedError("error during unmarshalling account id resolution URL response", errors.New(string(resp.data)))
}

infolog.Printf("Resolved account %s to id %s", accountName, accountIdURLResponse.Id)

return accountIdURLResponse.Id, nil
return accountIdURLResponse.Id, accountIdURLResponse.InfraVersion, nil
}

func (c *ClientImpl) getQueryParams(setStatements map[string]string) (map[string]string, error) {
Expand All @@ -150,7 +157,9 @@ func (c *ClientImpl) getQueryParams(setStatements map[string]string) (map[string
if len(c.AccountID) == 0 {
return nil, fmt.Errorf("Trying to run a query against system engine without account id defined")
}
params["account_id"] = c.AccountID
if _, ok := params["account_id"]; !ok {
params["account_id"] = c.AccountID
}
}
return params, nil
}
Expand All @@ -159,29 +168,62 @@ func (c *ClientImpl) getAccessToken() (string, error) {
return getAccessTokenServiceAccount(c.ClientID, c.ClientSecret, c.ApiEndpoint, c.UserAgent)
}

// GetEngineUrlAndDB returns engine URL and engine name based on engineName and accountId
func (c *ClientImpl) GetEngineUrlAndDB(ctx context.Context, engineName, databaseName string) (string, string, error) {
// Assume we are connected to a system engine in the beginning
c.ConnectedToSystemEngine = true
func (c *ClientImpl) getConnectionParametersV2(ctx context.Context, engineName, databaseName string) (string, map[string]string, error) {
engineURL := c.SystemEngineURL
parameters := make(map[string]string)
control := connectionControl{
updateParameters: func(key, value string) {
parameters[key] = value
},
setEngineURL: func(s string) {
engineURL = s
},
resetParameters: func() {},
}
if databaseName != "" {
if _, err := c.Query(ctx, engineURL, "USE DATABASE "+databaseName, parameters, control); err != nil {
return "", nil, err
}
}
if engineName != "" {
if _, err := c.Query(ctx, engineURL, "USE ENGINE "+engineName, parameters, control); err != nil {
return "", nil, err
}
}
return engineURL, parameters, nil
}

func (c *ClientImpl) getConnectionParametersV1(ctx context.Context, engineName, databaseName string) (string, map[string]string, error) {
// If engine name is empty, assume system engine
if len(engineName) == 0 {
return c.SystemEngineURL, databaseName, nil
return c.SystemEngineURL, map[string]string{"database": databaseName}, nil
}

engineUrl, status, dbName, err := c.getEngineUrlStatusDBByName(ctx, engineName, c.SystemEngineURL)
params := map[string]string{"database": dbName}
if err != nil {
return "", "", ConstructNestedError("error during getting engine info", err)
return "", params, ConstructNestedError("error during getting engine info", err)
}
if status != engineStatusRunning {
return "", "", fmt.Errorf("engine %s is not running", engineName)
return "", params, fmt.Errorf("engine %s is not running", engineName)
}
if len(dbName) == 0 {
return "", "", fmt.Errorf("engine %s not attached to any DB or you don't have permission to access its database", engineName)
return "", params, fmt.Errorf("engine %s not attached to any DB or you don't have permission to access its database", engineName)
}
if len(databaseName) != 0 && databaseName != dbName {
return "", "", fmt.Errorf("engine %s is not attached to database %s", engineName, databaseName)
return "", params, fmt.Errorf("engine %s is not attached to database %s", engineName, databaseName)
}
c.ConnectedToSystemEngine = false

return engineUrl, dbName, nil
return engineUrl, params, nil
}

// GetConnectionParameters returns engine URL and parameters based on engineName and databaseName
func (c *ClientImpl) GetConnectionParameters(ctx context.Context, engineName, databaseName string) (string, map[string]string, error) {
// Assume we are connected to a system engine in the beginning
c.ConnectedToSystemEngine = true
if c.AccountVersion == 2 {
return c.getConnectionParametersV2(ctx, engineName, databaseName)
}
return c.getConnectionParametersV1(ctx, engineName, databaseName)
}
93 changes: 76 additions & 17 deletions client_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,23 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"strings"
)

const outputFormat = "JSON_Compact"
const protocolVersionHeader = "Firebolt-Protocol-Version"
const protocolVersion = "2.0"
const protocolVersion = "2.1"

const updateParametersHeader = "Firebolt-Update-Parameters"
const updateEndpointHeader = "Firebolt-Update-Endpoint"
const resetSessionHeader = "Firebolt-Reset-Session"

var allowedUpdateParameters = []string{"database"}

type Client interface {
GetEngineUrlAndDB(ctx context.Context, engineName string, accountId string) (string, string, error)
Query(ctx context.Context, engineUrl, query string, parameters map[string]string, updateParameters func(string, string)) (*QueryResponse, error)
GetConnectionParameters(ctx context.Context, engineName string, databaseName string) (string, map[string]string, error)
Query(ctx context.Context, engineUrl, query string, parameters map[string]string, control connectionControl) (*QueryResponse, error)
}

type BaseClient struct {
Expand All @@ -41,8 +44,16 @@ type response struct {
err error
}

// connectionControl is a struct that holds methods for updating connection properties
// it's passed to Query method to allow it to update connection parameters and engine URL
type connectionControl struct {
updateParameters func(string, string)
resetParameters func()
setEngineURL func(string)
}

// Query sends a query to the engine URL and populates queryResponse, if query was successful
func (c *BaseClient) Query(ctx context.Context, engineUrl, query string, parameters map[string]string, updateParameters func(string, string)) (*QueryResponse, error) {
func (c *BaseClient) Query(ctx context.Context, engineUrl, query string, parameters map[string]string, control connectionControl) (*QueryResponse, error) {
infolog.Printf("Query engine '%s' with '%s'", engineUrl, query)

if c.parameterGetter == nil {
Expand All @@ -58,7 +69,7 @@ func (c *BaseClient) Query(ctx context.Context, engineUrl, query string, paramet
return nil, ConstructNestedError("error during query request", resp.err)
}

if err = processResponseHeaders(resp.headers, updateParameters); err != nil {
if err = c.processResponseHeaders(resp.headers, control); err != nil {
return nil, ConstructNestedError("error during processing response headers", err)
}

Expand Down Expand Up @@ -86,21 +97,69 @@ func contains(s []string, e string) bool {
return false
}

func processResponseHeaders(headers http.Header, updateParameters func(string, string)) error {
func handleUpdateParameters(updateParameters func(string, string), updateParametersRaw string) {
updateParametersPairs := strings.Split(updateParametersRaw, ",")
for _, parameter := range updateParametersPairs {
kv := strings.Split(parameter, "=")
if len(kv) != 2 {
infolog.Printf("Warning: invalid parameter assignment %s", parameter)
continue
}
if contains(allowedUpdateParameters, kv[0]) {
updateParameters(kv[0], kv[1])
} else {
infolog.Printf("Warning: received unknown update parameter %s", kv[0])
}
}
}

func splitEngineEndpoint(endpoint string) (string, url.Values, error) {
parsedUrl, err := url.Parse(endpoint)
if err != nil {
return "", nil, err
}
parameters, err := url.ParseQuery(parsedUrl.RawQuery)
if err != nil {
return "", nil, err
}
return parsedUrl.Host + parsedUrl.Path, parameters, nil
}

func (c *BaseClient) handleUpdateEndpoint(updateEndpointRaw string, control connectionControl) error {
// split URL containted into updateEndpointRaw into endpoint and parameters
// Update parameters and set client engine endpoint

corruptUrlError := errors.New("Failed to execute USE ENGINE command. Corrupt update endpoint. Contact support")
updateEndpoint, newParameters, err := splitEngineEndpoint(updateEndpointRaw)
if err != nil {
return corruptUrlError
}
if newParameters.Has("account_id") && newParameters.Get("account_id") != c.AccountID {
return errors.New("Failed to execute USE ENGINE command. Account parameter mismatch. Contact support")
}
// set engine URL as a full URL excluding query parameters
control.setEngineURL(updateEndpoint)
// update client parameters with new parameters
for k, v := range newParameters {
control.updateParameters(k, v[0])
}
return nil
}

func (c *BaseClient) processResponseHeaders(headers http.Header, control connectionControl) error {
if updateParametersRaw, ok := headers[updateParametersHeader]; ok {
updateParametersPairs := strings.Split(updateParametersRaw[0], ",")
for _, parameter := range updateParametersPairs {
kv := strings.Split(parameter, "=")
if len(kv) != 2 {
return fmt.Errorf("invalid parameter assignment %s", parameter)
}
if contains(allowedUpdateParameters, kv[0]) {
updateParameters(kv[0], kv[1])
} else {
infolog.Printf("Warning: received unknown update parameter %s", kv[0])
}
handleUpdateParameters(control.updateParameters, updateParametersRaw[0])
}

if updateEndpoint, ok := headers[updateEndpointHeader]; ok {
if err := c.handleUpdateEndpoint(updateEndpoint[0], control); err != nil {
return err
}
}
if _, ok := headers[resetSessionHeader]; ok {
control.resetParameters()
}

return nil
}

Expand Down
10 changes: 5 additions & 5 deletions client_common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@ func testProtocolVersion(t *testing.T, clientFactory func(string) Client) {

client := clientFactory(server.URL)

_, _ = client.Query(context.TODO(), server.URL, "SELECT 1", map[string]string{}, func(key, value string) {
// Do nothing
})
_, _ = client.Query(context.TODO(), server.URL, "SELECT 1", map[string]string{}, connectionControl{})
if protocolVersionValue != protocolVersion {
t.Errorf("Did not set Protocol-Version value correctly on a query request")
}
Expand All @@ -46,8 +44,10 @@ func testUpdateParameters(t *testing.T, clientFactory func(string) Client) {
params := map[string]string{
"database": "db",
}
_, err := client.Query(context.TODO(), server.URL, "SELECT 1", params, func(key, value string) {
params[key] = value
_, err := client.Query(context.TODO(), server.URL, "SELECT 1", params, connectionControl{
updateParameters: func(key, value string) {
params[key] = value
},
})
if err != nil {
t.Errorf("Error during query execution with update parameters header in response %s", err)
Expand Down
Loading

0 comments on commit da68727

Please sign in to comment.