Skip to content

Commit

Permalink
feat: Fir 28441 support use database statement (#81)
Browse files Browse the repository at this point in the history
  • Loading branch information
stepansergeevitch authored Jan 10, 2024
1 parent 42084ee commit 36a47ee
Show file tree
Hide file tree
Showing 18 changed files with 886 additions and 492 deletions.
16 changes: 8 additions & 8 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,13 @@ func getAccessTokenUsernamePassword(username string, password string, apiEndpoin
return "", err
}
infolog.Printf("Start authentication into '%s' using '%s'", apiEndpoint, loginUrl)
resp, err, _ := request(context.TODO(), "", "POST", apiEndpoint+loginUrl, userAgent, nil, body, contentType)
if err != nil {
return "", ConstructNestedError("authentication request failed", err)
resp := request(requestParameters{context.TODO(), "", "POST", apiEndpoint + loginUrl, userAgent, nil, body, contentType})
if resp.err != nil {
return "", ConstructNestedError("authentication request failed", resp.err)
}

var authResp AuthenticationResponse
if err = jsonStrictUnmarshall(resp, &authResp); err != nil {
if err = jsonStrictUnmarshall(resp.data, &authResp); err != nil {
return "", ConstructNestedError("failed to unmarshal authentication response with error", err)
}
infolog.Printf("Authentication was successful")
Expand Down Expand Up @@ -116,13 +116,13 @@ func getAccessTokenServiceAccount(clientId string, clientSecret string, apiEndpo
return "", ConstructNestedError("error building auth endpoint", err)
}
infolog.Printf("Start authentication into '%s' using '%s'", authEndpoint, loginUrl)
resp, err, _ := request(context.TODO(), "", "POST", authEndpoint+loginUrl, userAgent, nil, body, contentType)
if err != nil {
return "", ConstructNestedError("authentication request failed", err)
resp := request(requestParameters{context.TODO(), "", "POST", authEndpoint + loginUrl, userAgent, nil, body, contentType})
if resp.err != nil {
return "", ConstructNestedError("authentication request failed", resp.err)
}

var authResp AuthenticationResponse
if err = jsonStrictUnmarshall(resp, &authResp); err != nil {
if err = jsonStrictUnmarshall(resp.data, &authResp); err != nil {
return "", ConstructNestedError("failed to unmarshal authentication response with error", err)
}
infolog.Printf("Authentication was successful")
Expand Down
33 changes: 16 additions & 17 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ func MakeClient(settings *fireboltSettings, apiEndpoint string) (*ClientImpl, er
func (c *ClientImpl) getEngineUrlStatusDBByName(ctx context.Context, engineName string, systemEngineUrl string) (string, string, string, error) {
infolog.Printf("Get info for engine '%s'", engineName)
engineSQL := fmt.Sprintf(engineInfoSQL, engineName)
queryRes, err := c.Query(ctx, systemEngineUrl, "", engineSQL, make(map[string]string))
queryRes, err := c.Query(ctx, systemEngineUrl, engineSQL, make(map[string]string), func(key, value string) {
// No need to support set statements for engine info query
})
if err != nil {
return "", "", "", ConstructNestedError("error executing engine info sql query", err)
}
Expand Down Expand Up @@ -94,17 +96,17 @@ func (c *ClientImpl) getSystemEngineURL(ctx context.Context, accountName string)

url := fmt.Sprintf(c.ApiEndpoint+EngineUrlByAccountName, accountName)

response, err, err_code := c.request(ctx, "GET", url, make(map[string]string), "")
if err_code == 404 {
resp := c.request(ctx, "GET", url, make(map[string]string), "")
if resp.statusCode == 404 {
return "", fmt.Errorf(accountError, accountName)
}
if err != nil {
return "", ConstructNestedError("error during system engine url http request", err)
if resp.err != nil {
return "", ConstructNestedError("error during system engine url http request", resp.err)
}

var systemEngineURLResponse SystemEngineURLResponse
if err = json.Unmarshal(response, &systemEngineURLResponse); err != nil {
return "", ConstructNestedError("error during unmarshalling system engine URL response", errors.New(string(response)))
if err := json.Unmarshal(resp.data, &systemEngineURLResponse); err != nil {
return "", ConstructNestedError("error during unmarshalling system engine URL response", errors.New(string(resp.data)))
}

return systemEngineURLResponse.EngineUrl, nil
Expand All @@ -120,29 +122,26 @@ func (c *ClientImpl) getAccountID(ctx context.Context, accountName string) (stri

url := fmt.Sprintf(c.ApiEndpoint+AccountIdByAccountName, accountName)

response, err, err_code := c.request(ctx, "GET", url, make(map[string]string), "")
if err_code == 404 {
resp := c.request(ctx, "GET", url, make(map[string]string), "")
if resp.statusCode == 404 {
return "", fmt.Errorf(accountError, accountName)
}
if err != nil {
return "", ConstructNestedError("error during account id resolution http request", err)
if resp.err != nil {
return "", ConstructNestedError("error during account id resolution http request", resp.err)
}

var accountIdURLResponse AccountIdURLResponse
if err = json.Unmarshal(response, &accountIdURLResponse); err != nil {
return "", ConstructNestedError("error during unmarshalling account id resolution URL response", errors.New(string(response)))
if err := json.Unmarshal(resp.data, &accountIdURLResponse); err != nil {
return "", ConstructNestedError("error during unmarshalling account id resolution URL response", errors.New(string(resp.data)))
}

infolog.Printf("Resolved account %s to id %s", accountName, accountIdURLResponse.Id)

return accountIdURLResponse.Id, nil
}

func (c *ClientImpl) getQueryParams(databaseName string, setStatements map[string]string) (map[string]string, error) {
func (c *ClientImpl) getQueryParams(setStatements map[string]string) (map[string]string, error) {
params := map[string]string{"output_format": outputFormat}
if len(databaseName) > 0 {
params["database"] = databaseName
}
for setKey, setValue := range setStatements {
params[setKey] = setValue
}
Expand Down
131 changes: 95 additions & 36 deletions client_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,22 @@ import (
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"io"
"net/http"
"strings"
)

const outputFormat = "JSON_Compact"
const protocolVersionHeader = "Firebolt-Protocol-Version"
const protocolVersion = "2.0"

const updateParametersHeader = "Firebolt-Update-Parameters"

var allowedUpdateParameters = []string{"database"}

type Client interface {
GetEngineUrlAndDB(ctx context.Context, engineName string, accountId string) (string, string, error)
Query(ctx context.Context, engineUrl, databaseName, query string, setStatements map[string]string) (*QueryResponse, error)
Query(ctx context.Context, engineUrl, query string, parameters map[string]string, updateParameters func(string, string)) (*QueryResponse, error)
}

type BaseClient struct {
Expand All @@ -24,69 +30,106 @@ type BaseClient struct {
AccountID string
ApiEndpoint string
UserAgent string
parameterGetter func(string, map[string]string) (map[string]string, error)
parameterGetter func(map[string]string) (map[string]string, error)
accessTokenGetter func() (string, error)
}

type response struct {
data []byte
statusCode int
headers http.Header
err error
}

// Query sends a query to the engine URL and populates queryResponse, if query was successful
func (c *BaseClient) Query(ctx context.Context, engineUrl, databaseName, query string, setStatements map[string]string) (*QueryResponse, error) {
func (c *BaseClient) Query(ctx context.Context, engineUrl, query string, parameters map[string]string, updateParameters func(string, string)) (*QueryResponse, error) {
infolog.Printf("Query engine '%s' with '%s'", engineUrl, query)

if c.parameterGetter == nil {
return nil, errors.New("parameterGetter is not set")
}
params, err := c.parameterGetter(databaseName, setStatements)
params, err := c.parameterGetter(parameters)
if err != nil {
return nil, err
}

response, err, _ := c.request(ctx, "POST", engineUrl, params, query)
if err != nil {
return nil, ConstructNestedError("error during query request", err)
resp := c.request(ctx, "POST", engineUrl, params, query)
if resp.err != nil {
return nil, ConstructNestedError("error during query request", resp.err)
}

if err = processResponseHeaders(resp.headers, updateParameters); err != nil {
return nil, ConstructNestedError("error during processing response headers", err)
}

var queryResponse QueryResponse
if len(response) == 0 {
if len(resp.data) == 0 {
// response could be empty, which doesn't mean it is an error
return &queryResponse, nil
}

if err = json.Unmarshal(response, &queryResponse); err != nil {
return nil, ConstructNestedError("wrong response", errors.New(string(response)))
if err = json.Unmarshal(resp.data, &queryResponse); err != nil {
return nil, ConstructNestedError("wrong response", errors.New(string(resp.data)))
}

infolog.Printf("Query was successful")
return &queryResponse, nil
}

// check whether a string is present in a slice
func contains(s []string, e string) bool {
for _, a := range s {
if a == e {
return true
}
}
return false
}

func processResponseHeaders(headers http.Header, updateParameters func(string, string)) error {
if updateParametersRaw, ok := headers[updateParametersHeader]; ok {
updateParametersPairs := strings.Split(updateParametersRaw[0], ",")
for _, parameter := range updateParametersPairs {
kv := strings.Split(parameter, "=")
if len(kv) != 2 {
return fmt.Errorf("invalid parameter assignment %s", parameter)
}
if contains(allowedUpdateParameters, kv[0]) {
updateParameters(kv[0], kv[1])
} else {
infolog.Printf("Warning: received unknown update parameter %s", kv[0])
}
}
}
return nil
}

// request fetches an access token from the cache or re-authenticate when the access token is not available in the cache
// and sends a request using that token
func (c *BaseClient) request(ctx context.Context, method string, url string, params map[string]string, bodyStr string) ([]byte, error, int) {
func (c *BaseClient) request(ctx context.Context, method string, url string, params map[string]string, bodyStr string) response {
var err error

if c.accessTokenGetter == nil {
return nil, errors.New("accessTokenGetter is not set"), 0
return response{nil, 0, nil, errors.New("accessTokenGetter is not set")}
}

accessToken, err := c.accessTokenGetter()
if err != nil {
return nil, ConstructNestedError("error while getting access token", err), 0
return response{nil, 0, nil, ConstructNestedError("error while getting access token", err)}
}
var response []byte
var responseCode int
response, err, responseCode = request(ctx, accessToken, method, url, c.UserAgent, params, bodyStr, ContentTypeJSON)
if responseCode == http.StatusUnauthorized {
resp := request(requestParameters{ctx, accessToken, method, url, c.UserAgent, params, bodyStr, ContentTypeJSON})
if resp.statusCode == http.StatusUnauthorized {
deleteAccessTokenFromCache(c.ClientID, c.ApiEndpoint)

// Refreshing the access token as it is expired
accessToken, err = c.accessTokenGetter()
if err != nil {
return nil, ConstructNestedError("error while refreshing access token", err), 0
return response{nil, 0, nil, ConstructNestedError("error while getting access token", err)}
}
// Trying to send the same request again now that the access token has been refreshed
response, err, responseCode = request(ctx, accessToken, method, url, c.UserAgent, params, bodyStr, ContentTypeJSON)
resp = request(requestParameters{ctx, accessToken, method, url, c.UserAgent, params, bodyStr, ContentTypeJSON})
}
return response, err, responseCode
return resp
}

// makeCanonicalUrl checks whether url starts with https:// and if not prepends it
Expand Down Expand Up @@ -118,27 +161,43 @@ func checkErrorResponse(response []byte) error {
return nil
}

// Collect arguments for request function
type requestParameters struct {
ctx context.Context
accessToken string
method string
url string
userAgent string
params map[string]string
bodyStr string
contentType string
}

// request sends a request using "POST" or "GET" method on a specified url
// additionally it passes the parameters and a bodyStr as a payload
// if accessToken is passed, it is used for authorization
// returns response and an error
func request(ctx context.Context, accessToken string, method string, url string, userAgent string, params map[string]string, bodyStr string, contentType string) ([]byte, error, int) {
req, _ := http.NewRequestWithContext(ctx, method, makeCanonicalUrl(url), strings.NewReader(bodyStr))
func request(
reqParams requestParameters) response {
req, _ := http.NewRequestWithContext(reqParams.ctx, reqParams.method, makeCanonicalUrl(reqParams.url), strings.NewReader(reqParams.bodyStr))

// adding sdk usage tracking
req.Header.Set("User-Agent", userAgent)
req.Header.Set("User-Agent", reqParams.userAgent)

// add protocol version header
req.Header.Set(protocolVersionHeader, protocolVersion)

if len(accessToken) > 0 {
var bearer = "Bearer " + accessToken
if len(reqParams.accessToken) > 0 {
var bearer = "Bearer " + reqParams.accessToken
req.Header.Add("Authorization", bearer)
}

if len(contentType) > 0 {
req.Header.Set("Content-Type", contentType)
if len(reqParams.contentType) > 0 {
req.Header.Set("Content-Type", reqParams.contentType)
}

q := req.URL.Query()
for key, value := range params {
for key, value := range reqParams.params {
q.Add(key, value)
}
req.URL.RawQuery = q.Encode()
Expand All @@ -147,29 +206,29 @@ func request(ctx context.Context, accessToken string, method string, url string,
resp, err := client.Do(req)
if err != nil {
infolog.Println(err)
return nil, ConstructNestedError("error during a request execution", err), 0
return response{nil, 0, nil, ConstructNestedError("error during a request execution", err)}
}

defer resp.Body.Close()

body, err := ioutil.ReadAll(resp.Body)
body, err := io.ReadAll(resp.Body)
if err != nil {
infolog.Println(err)
return nil, ConstructNestedError("error during reading a request response", err), 0
return response{nil, 0, nil, ConstructNestedError("error during reading a request response", err)}
}

if !(resp.StatusCode >= 200 && resp.StatusCode < 300) {
if err = checkErrorResponse(body); err != nil {
return nil, ConstructNestedError("request returned an error", err), resp.StatusCode
return response{nil, resp.StatusCode, nil, ConstructNestedError("request returned an error", err)}
}
if resp.StatusCode == 500 {
// this is a database error
return nil, fmt.Errorf("%s", string(body)), resp.StatusCode
return response{nil, resp.StatusCode, nil, fmt.Errorf("%s", string(body))}
}
return nil, fmt.Errorf("request returned non ok status code: %d, %s", resp.StatusCode, string(body)), resp.StatusCode
return response{nil, resp.StatusCode, nil, fmt.Errorf("request returned non ok status code: %d, %s", resp.StatusCode, string(body))}
}

return body, nil, resp.StatusCode
return response{body, resp.StatusCode, resp.Header, nil}
}

// jsonStrictUnmarshall unmarshalls json into object, and returns an error
Expand Down
Loading

0 comments on commit 36a47ee

Please sign in to comment.