Skip to content

Commit

Permalink
Implement multistatements
Browse files Browse the repository at this point in the history
  • Loading branch information
yuryfirebolt committed Aug 18, 2022
1 parent 5e09485 commit 6ca2f85
Show file tree
Hide file tree
Showing 7 changed files with 202 additions and 43 deletions.
41 changes: 28 additions & 13 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,32 +38,46 @@ func (c *fireboltConnection) Begin() (driver.Tx, error) {

// ExecContext sends the query to the engine and returns empty fireboltResult
func (c *fireboltConnection) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
_, err := c.QueryContext(ctx, query, args)
_, err := c.queryContextInternal(ctx, query, args, false)
return &FireboltResult{}, err
}

// QueryContext sends the query to the engine and returns fireboltRows
func (c *fireboltConnection) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
return c.queryContextInternal(ctx, query, args, true)
}

func (c *fireboltConnection) queryContextInternal(ctx context.Context, query string, args []driver.NamedValue, isMultiStatementAllowed bool) (driver.Rows, error) {
query, err := prepareStatement(query, args)
if err != nil {
return nil, ConstructNestedError("error during preparing a statement", err)
}
queries, err := SplitStatements(query)
if err != nil {
return nil, ConstructNestedError("error during splitting query", err)
}
if len(queries) > 1 && !isMultiStatementAllowed {
return nil, fmt.Errorf("multistatement is not allowed")
}

if isSetStatement, err := processSetStatement(ctx, c, query); isSetStatement {
if err == nil {
return &fireboltRows{QueryResponse{}, 0}, nil
} else {
return nil, ConstructNestedError("statement recognized as an invalid set statement", err)
var rows fireboltRows
for _, query := range queries {
if isSetStatement, err := processSetStatement(ctx, c, query); isSetStatement {
if err == nil {
rows.response = append(rows.response, QueryResponse{})
continue
} else {
return &rows, ConstructNestedError("statement recognized as an invalid set statement", err)
}
}
}

queryResponse, err := c.client.Query(ctx, c.engineUrl, c.databaseName, query, c.setStatements)
if err != nil {
return nil, ConstructNestedError("error during query execution", err)
if response, err := c.client.Query(ctx, c.engineUrl, c.databaseName, query, c.setStatements); err != nil {
return &rows, ConstructNestedError("error during query execution", err)
} else {
rows.response = append(rows.response, *response)
}
}

return &fireboltRows{*queryResponse, 0}, nil
return &rows, nil
}

// processSetStatement is an internal function for checking whether query is a valid set statement
Expand All @@ -75,7 +89,8 @@ func processSetStatement(ctx context.Context, c *fireboltConnection, query strin
return false, nil
}

_, err = c.client.Query(ctx, c.engineUrl, c.databaseName, "SELECT 1", map[string]string{setKey: setValue})
_, err = c.client.Query(ctx, c.engineUrl, c.databaseName, "SELECT 1",
map[string]string{setKey: setValue, "advanced_mode": "1", "hidden_query": "1"})
if err == nil {
c.setStatements[setKey] = setValue
return true, nil
Expand Down
30 changes: 28 additions & 2 deletions connection_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func TestConnectionInsertQuery(t *testing.T) {
func TestConnectionQuery(t *testing.T) {
conn := fireboltConnection{clientMock, databaseMock, engineUrlMock, map[string]string{}}

sql := "SELECT 3213212 as \"const\", 2.3 as \"float\", 'some_text' as \"text\""
sql := "SELECT -3213212 as \"const\", 2.3 as \"float\", 'some_text' as \"text\""
rows, err := conn.QueryContext(context.TODO(), sql, nil)
if err != nil {
t.Errorf("firebolt statement failed with %v", err)
Expand All @@ -79,7 +79,7 @@ func TestConnectionQuery(t *testing.T) {
if err != nil {
t.Errorf("Next returned an error, but shouldn't")
}
assert(dest[0] == uint32(3213212), t, "dest[0] is not equal")
assert(dest[0] == int32(-3213212), t, "dest[0] is not equal")
assert(dest[1] == float64(2.3), t, "dest[1] is not equal")
assert(dest[2] == "some_text", t, "dest[2] is not equal")

Expand Down Expand Up @@ -141,3 +141,29 @@ func TestConnectionQueryDateTime64Type(t *testing.T) {
t.Errorf("values are not equal: %v and %v\n", dest[0], expected)
}
}

func TestConnectionMultipleStatement(t *testing.T) {
conn := fireboltConnection{clientMock, databaseMock, engineUrlMock, map[string]string{}}
if rows, err := conn.QueryContext(context.TODO(), "SELECT -1; SELECT -2", nil); err != nil {
t.Errorf("Query multistement returned err: %v", err)
} else {
dest := make([]driver.Value, 1)

err = rows.Next(dest)
assert(err == nil, t, "rows.Next returned an error")
assert(dest[0] == int32(-1), t, "results are not equal")

if nextResultSet, ok := rows.(driver.RowsNextResultSet); !ok {
t.Errorf("multistatement didn't return RowsNextResultSet")
} else {
assert(nextResultSet.HasNextResultSet(), t, "HasNextResultSet returned false")
assert(nextResultSet.NextResultSet() == nil, t, "NextResultSet returned an error")

err = rows.Next(dest)
assert(err == nil, t, "rows.Next returned an error")
assert(dest[0] == int32(-2), t, "results are not equal")

assert(!nextResultSet.HasNextResultSet(), t, "HasNextResultSet returned true")
}
}
}
35 changes: 27 additions & 8 deletions rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,17 @@ const (
)

type fireboltRows struct {
response QueryResponse
cursorPosition int
response []QueryResponse
resultSetPosition int // Position of the result set (for multiple statements)
cursorPosition int // Position of the cursor in current result set
}

// Columns returns a list of Meta names in response
func (f *fireboltRows) Columns() []string {
numColumns := len(f.response.Meta)
numColumns := len(f.response[f.resultSetPosition].Meta)
result := make([]string, 0, numColumns)

for _, column := range f.response.Meta {
for _, column := range f.response[f.resultSetPosition].Meta {
result = append(result, column.Name)
}

Expand All @@ -46,20 +47,21 @@ func (f *fireboltRows) Columns() []string {

// Close makes the rows unusable
func (f *fireboltRows) Close() error {
f.cursorPosition = len(f.response.Data)
f.resultSetPosition = len(f.response) - 1
f.cursorPosition = len(f.response[f.resultSetPosition].Data)
return nil
}

// Next fetches the values of the next row, returns io.EOF if it was the end
func (f *fireboltRows) Next(dest []driver.Value) error {
if f.cursorPosition == len(f.response.Data) {
if f.cursorPosition == len(f.response[f.resultSetPosition].Data) {
return io.EOF
}

for i, column := range f.response.Meta {
for i, column := range f.response[f.resultSetPosition].Meta {
var err error
//log.Printf("Rows.Next: %s, %v", column.Type, f.response.Data[f.cursorPosition][i])
if dest[i], err = parseValue(column.Type, f.response.Data[f.cursorPosition][i]); err != nil {
if dest[i], err = parseValue(column.Type, f.response[f.resultSetPosition].Data[f.cursorPosition][i]); err != nil {
return ConstructNestedError("error during fetching Next result", err)
}
}
Expand All @@ -68,6 +70,23 @@ func (f *fireboltRows) Next(dest []driver.Value) error {
return nil
}

// HasNextResultSet reports whether there is another result set available
func (f *fireboltRows) HasNextResultSet() bool {
return len(f.response) > f.resultSetPosition+1
}

// NextResultSet advances to the next result set, if it is available, otherwise returns io.EOF
func (f *fireboltRows) NextResultSet() error {
if !f.HasNextResultSet() {
return io.EOF
}

f.cursorPosition = 0
f.resultSetPosition += 1

return nil
}

// checkTypeValue checks that val type could be changed to columnType
func checkTypeValue(columnType string, val interface{}) error {
switch strings.ToUpper(columnType) {
Expand Down
70 changes: 59 additions & 11 deletions rows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ func assert(val bool, t *testing.T, err string) {
}
}

func mockRows() driver.Rows {
resultJson := "{" +
func mockRows(isMultiStatement bool) driver.RowsNextResultSet {
resultJson := []string{"{" +
"\"query\":{\"query_id\":\"16FF2A0300ECA753\"}," +
"\"meta\":[" +
" {\"name\":\"int_col\",\"type\":\"Nullable(Int32)\"}," +
Expand All @@ -41,19 +41,40 @@ func mockRows() driver.Rows {
" \"time_before_execution\":0.001251613," +
" \"time_to_execute\":0.000544098," +
" \"scanned_bytes_cache\":2003," +
" \"scanned_bytes_storage\":0}}"
" \"scanned_bytes_storage\":0}}", "" +
"{" +
"\"query\":{\"query_id\":\"16FF2A0300ECA753\"}," +
"\"meta\":[{\"name\":\"int_col\",\"type\":\"Nullable(Int32)\"}]," +
"\"data\":[[3], [null]]," +
"\"rows\":2," +
"\"statistics\":{" +
" \"elapsed\":0.001797702," +
" \"rows_read\":2," +
" \"bytes_read\":293," +
" \"time_before_execution\":0.001251613," +
" \"time_to_execute\":0.000544098," +
" \"scanned_bytes_cache\":2003," +
" \"scanned_bytes_storage\":0}}"}

var response QueryResponse
err := json.Unmarshal([]byte(resultJson), &response)
if err != nil {
panic("Error in test code")
var responses []QueryResponse
for i := 0; i < 2; i += 1 {
if i != 0 && !isMultiStatement {
break
}
var response QueryResponse
if err := json.Unmarshal([]byte(resultJson[i]), &response); err != nil {
panic("Error in test code")
} else {
responses = append(responses, response)
}
}
return &fireboltRows{response, 0}

return &fireboltRows{responses, 0, 0}
}

// TestRowsColumns checks, that correct column names are returned
func TestRowsColumns(t *testing.T) {
rows := mockRows()
rows := mockRows(false)

columnNames := []string{"int_col", "bigint_col", "float_col", "double_col", "text_col", "date_col", "timestamp_col", "bool_col", "array_col", "nested_array_col"}
if !reflect.DeepEqual(rows.Columns(), columnNames) {
Expand All @@ -63,7 +84,7 @@ func TestRowsColumns(t *testing.T) {

// TestRowsClose checks Close method, and inability to use rows afterward
func TestRowsClose(t *testing.T) {
rows := mockRows()
rows := mockRows(false)
if rows.Close() != nil {
t.Errorf("Closing rows was not successful")
}
Expand All @@ -76,7 +97,7 @@ func TestRowsClose(t *testing.T) {

// TestRowsNext check Next method
func TestRowsNext(t *testing.T) {
rows := mockRows()
rows := mockRows(false)
var dest = make([]driver.Value, 10)
err := rows.Next(dest)
loc, _ := time.LoadLocation("UTC")
Expand All @@ -103,4 +124,31 @@ func TestRowsNext(t *testing.T) {
assert(dest[4] == "text", t, "results not equal for string")

assert(io.EOF == rows.Next(dest), t, "Next should return io.EOF if no data available anymore")

assert(rows.HasNextResultSet() == false, t, "Has Next result set didn't return false")
assert(rows.NextResultSet() == io.EOF, t, "Next result set didn't return false")
}

// TestRowsNextSet check rows with multiple statements
func TestRowsNextSet(t *testing.T) {
rows := mockRows(true)

// check next result set functions
assert(rows.HasNextResultSet() == true, t, "HasNextResultSet returned false, but shouldn't")
assert(rows.NextResultSet() == nil, t, "NextResultSet returned an error, but shouldn't")
assert(rows.HasNextResultSet() == false, t, "HasNextResultSet returned true, but shouldn't")

// check columns of the next result set
assert(reflect.DeepEqual(rows.Columns(), []string{"int_col"}), t, "Columns of the next result set are incorrect")

// check values of the next result set
var dest = make([]driver.Value, 1)

assert(rows.Next(dest) == nil, t, "Next shouldn't return an error")
assert(dest[0] == int32(3), t, "results are not equal")

assert(rows.Next(dest) == nil, t, "Next shouldn't return an error")
assert(dest[0] == nil, t, "results are not equal")

assert(io.EOF == rows.Next(dest), t, "Next should return io.EOF if no data available anymore")
}
11 changes: 2 additions & 9 deletions statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package fireboltgosdk
import (
"context"
"database/sql/driver"
"fmt"
)

type Column struct {
Expand Down Expand Up @@ -42,18 +41,12 @@ func (stmt *fireboltStmt) NumInput() int {

// Exec calls ExecContext with dummy context
func (stmt *fireboltStmt) Exec(args []driver.Value) (driver.Result, error) {
if len(args) != 0 {
return nil, fmt.Errorf("Prepared statements are not implemented")
}
return stmt.ExecContext(context.TODO(), make([]driver.NamedValue, 0))
return stmt.ExecContext(context.TODO(), valueToNamedValue(args))
}

// Query calls QueryContext with dummy context
func (stmt *fireboltStmt) Query(args []driver.Value) (driver.Rows, error) {
if len(args) != 0 {
return nil, fmt.Errorf("Prepared statements are not implemented")
}
return stmt.QueryContext(context.TODO(), make([]driver.NamedValue, 0))
return stmt.QueryContext(context.TODO(), valueToNamedValue(args))
}

// QueryContext sends the query to the engine and returns fireboltRows
Expand Down
26 changes: 26 additions & 0 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,24 @@ func prepareStatement(query string, params []driver.NamedValue) (string, error)
return query, nil
}

// SplitStatements split multiple statements into a list of statements
func SplitStatements(sql string) ([]string, error) {
var queries []string

for sql != "" {
var err error
var query string

query, sql, err = sqlparser.SplitStatement(sql)
if err != nil {
return nil, ConstructNestedError("error during splitting query", err)
}
queries = append(queries, query)
}

return queries, nil
}

func formatValue(value driver.Value) (string, error) {
switch v := value.(type) {
case string:
Expand Down Expand Up @@ -134,3 +152,11 @@ func ConstructUserAgentString() string {

return strings.TrimSpace(fmt.Sprintf("%s GoSDK/%s (Go %s; %s) %s", goClients, sdkVersion, runtime.Version(), osNameVersion, goDrivers))
}

func valueToNamedValue(args []driver.Value) []driver.NamedValue {
namedValues := make([]driver.NamedValue, 0, len(args))
for i, arg := range args {
namedValues = append(namedValues, driver.NamedValue{Ordinal: i, Value: arg})
}
return namedValues
}
Loading

0 comments on commit 6ca2f85

Please sign in to comment.