diff --git a/connection.go b/connection.go index 2e4dca5..4292c02 100644 --- a/connection.go +++ b/connection.go @@ -38,32 +38,46 @@ func (c *fireboltConnection) Begin() (driver.Tx, error) { // ExecContext sends the query to the engine and returns empty fireboltResult func (c *fireboltConnection) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { - _, err := c.QueryContext(ctx, query, args) + _, err := c.queryContextInternal(ctx, query, args, false) return &FireboltResult{}, err } // QueryContext sends the query to the engine and returns fireboltRows func (c *fireboltConnection) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + return c.queryContextInternal(ctx, query, args, true) +} +func (c *fireboltConnection) queryContextInternal(ctx context.Context, query string, args []driver.NamedValue, isMultiStatementAllowed bool) (driver.Rows, error) { query, err := prepareStatement(query, args) if err != nil { return nil, ConstructNestedError("error during preparing a statement", err) } + queries, err := SplitStatements(query) + if err != nil { + return nil, ConstructNestedError("error during splitting query", err) + } + if len(queries) > 1 && !isMultiStatementAllowed { + return nil, fmt.Errorf("multistatement is not allowed") + } - if isSetStatement, err := processSetStatement(ctx, c, query); isSetStatement { - if err == nil { - return &fireboltRows{QueryResponse{}, 0}, nil - } else { - return nil, ConstructNestedError("statement recognized as an invalid set statement", err) + var rows fireboltRows + for _, query := range queries { + if isSetStatement, err := processSetStatement(ctx, c, query); isSetStatement { + if err == nil { + rows.response = append(rows.response, QueryResponse{}) + continue + } else { + return &rows, ConstructNestedError("statement recognized as an invalid set statement", err) + } } - } - queryResponse, err := c.client.Query(ctx, c.engineUrl, c.databaseName, query, c.setStatements) - if err != nil { - return nil, ConstructNestedError("error during query execution", err) + if response, err := c.client.Query(ctx, c.engineUrl, c.databaseName, query, c.setStatements); err != nil { + return &rows, ConstructNestedError("error during query execution", err) + } else { + rows.response = append(rows.response, *response) + } } - - return &fireboltRows{*queryResponse, 0}, nil + return &rows, nil } // processSetStatement is an internal function for checking whether query is a valid set statement @@ -75,7 +89,8 @@ func processSetStatement(ctx context.Context, c *fireboltConnection, query strin return false, nil } - _, err = c.client.Query(ctx, c.engineUrl, c.databaseName, "SELECT 1", map[string]string{setKey: setValue}) + _, err = c.client.Query(ctx, c.engineUrl, c.databaseName, "SELECT 1", + map[string]string{setKey: setValue, "advanced_mode": "1", "hidden_query": "1"}) if err == nil { c.setStatements[setKey] = setValue return true, nil diff --git a/connection_integration_test.go b/connection_integration_test.go index edfefd2..1a431ee 100644 --- a/connection_integration_test.go +++ b/connection_integration_test.go @@ -63,7 +63,7 @@ func TestConnectionInsertQuery(t *testing.T) { func TestConnectionQuery(t *testing.T) { conn := fireboltConnection{clientMock, databaseMock, engineUrlMock, map[string]string{}} - sql := "SELECT 3213212 as \"const\", 2.3 as \"float\", 'some_text' as \"text\"" + sql := "SELECT -3213212 as \"const\", 2.3 as \"float\", 'some_text' as \"text\"" rows, err := conn.QueryContext(context.TODO(), sql, nil) if err != nil { t.Errorf("firebolt statement failed with %v", err) @@ -79,7 +79,7 @@ func TestConnectionQuery(t *testing.T) { if err != nil { t.Errorf("Next returned an error, but shouldn't") } - assert(dest[0] == uint32(3213212), t, "dest[0] is not equal") + assert(dest[0] == int32(-3213212), t, "dest[0] is not equal") assert(dest[1] == float64(2.3), t, "dest[1] is not equal") assert(dest[2] == "some_text", t, "dest[2] is not equal") @@ -141,3 +141,29 @@ func TestConnectionQueryDateTime64Type(t *testing.T) { t.Errorf("values are not equal: %v and %v\n", dest[0], expected) } } + +func TestConnectionMultipleStatement(t *testing.T) { + conn := fireboltConnection{clientMock, databaseMock, engineUrlMock, map[string]string{}} + if rows, err := conn.QueryContext(context.TODO(), "SELECT -1; SELECT -2", nil); err != nil { + t.Errorf("Query multistement returned err: %v", err) + } else { + dest := make([]driver.Value, 1) + + err = rows.Next(dest) + assert(err == nil, t, "rows.Next returned an error") + assert(dest[0] == int32(-1), t, "results are not equal") + + if nextResultSet, ok := rows.(driver.RowsNextResultSet); !ok { + t.Errorf("multistatement didn't return RowsNextResultSet") + } else { + assert(nextResultSet.HasNextResultSet(), t, "HasNextResultSet returned false") + assert(nextResultSet.NextResultSet() == nil, t, "NextResultSet returned an error") + + err = rows.Next(dest) + assert(err == nil, t, "rows.Next returned an error") + assert(dest[0] == int32(-2), t, "results are not equal") + + assert(!nextResultSet.HasNextResultSet(), t, "HasNextResultSet returned true") + } + } +} diff --git a/rows.go b/rows.go index d2e0d80..192c044 100644 --- a/rows.go +++ b/rows.go @@ -28,16 +28,17 @@ const ( ) type fireboltRows struct { - response QueryResponse - cursorPosition int + response []QueryResponse + resultSetPosition int // Position of the result set (for multiple statements) + cursorPosition int // Position of the cursor in current result set } // Columns returns a list of Meta names in response func (f *fireboltRows) Columns() []string { - numColumns := len(f.response.Meta) + numColumns := len(f.response[f.resultSetPosition].Meta) result := make([]string, 0, numColumns) - for _, column := range f.response.Meta { + for _, column := range f.response[f.resultSetPosition].Meta { result = append(result, column.Name) } @@ -46,20 +47,21 @@ func (f *fireboltRows) Columns() []string { // Close makes the rows unusable func (f *fireboltRows) Close() error { - f.cursorPosition = len(f.response.Data) + f.resultSetPosition = len(f.response) - 1 + f.cursorPosition = len(f.response[f.resultSetPosition].Data) return nil } // Next fetches the values of the next row, returns io.EOF if it was the end func (f *fireboltRows) Next(dest []driver.Value) error { - if f.cursorPosition == len(f.response.Data) { + if f.cursorPosition == len(f.response[f.resultSetPosition].Data) { return io.EOF } - for i, column := range f.response.Meta { + for i, column := range f.response[f.resultSetPosition].Meta { var err error //log.Printf("Rows.Next: %s, %v", column.Type, f.response.Data[f.cursorPosition][i]) - if dest[i], err = parseValue(column.Type, f.response.Data[f.cursorPosition][i]); err != nil { + if dest[i], err = parseValue(column.Type, f.response[f.resultSetPosition].Data[f.cursorPosition][i]); err != nil { return ConstructNestedError("error during fetching Next result", err) } } @@ -68,6 +70,23 @@ func (f *fireboltRows) Next(dest []driver.Value) error { return nil } +// HasNextResultSet reports whether there is another result set available +func (f *fireboltRows) HasNextResultSet() bool { + return len(f.response) > f.resultSetPosition+1 +} + +// NextResultSet advances to the next result set, if it is available, otherwise returns io.EOF +func (f *fireboltRows) NextResultSet() error { + if !f.HasNextResultSet() { + return io.EOF + } + + f.cursorPosition = 0 + f.resultSetPosition += 1 + + return nil +} + // checkTypeValue checks that val type could be changed to columnType func checkTypeValue(columnType string, val interface{}) error { switch strings.ToUpper(columnType) { diff --git a/rows_test.go b/rows_test.go index e41b0f7..7141af1 100644 --- a/rows_test.go +++ b/rows_test.go @@ -15,8 +15,8 @@ func assert(val bool, t *testing.T, err string) { } } -func mockRows() driver.Rows { - resultJson := "{" + +func mockRows(isMultiStatement bool) driver.RowsNextResultSet { + resultJson := []string{"{" + "\"query\":{\"query_id\":\"16FF2A0300ECA753\"}," + "\"meta\":[" + " {\"name\":\"int_col\",\"type\":\"Nullable(Int32)\"}," + @@ -41,19 +41,40 @@ func mockRows() driver.Rows { " \"time_before_execution\":0.001251613," + " \"time_to_execute\":0.000544098," + " \"scanned_bytes_cache\":2003," + - " \"scanned_bytes_storage\":0}}" + " \"scanned_bytes_storage\":0}}", "" + + "{" + + "\"query\":{\"query_id\":\"16FF2A0300ECA753\"}," + + "\"meta\":[{\"name\":\"int_col\",\"type\":\"Nullable(Int32)\"}]," + + "\"data\":[[3], [null]]," + + "\"rows\":2," + + "\"statistics\":{" + + " \"elapsed\":0.001797702," + + " \"rows_read\":2," + + " \"bytes_read\":293," + + " \"time_before_execution\":0.001251613," + + " \"time_to_execute\":0.000544098," + + " \"scanned_bytes_cache\":2003," + + " \"scanned_bytes_storage\":0}}"} - var response QueryResponse - err := json.Unmarshal([]byte(resultJson), &response) - if err != nil { - panic("Error in test code") + var responses []QueryResponse + for i := 0; i < 2; i += 1 { + if i != 0 && !isMultiStatement { + break + } + var response QueryResponse + if err := json.Unmarshal([]byte(resultJson[i]), &response); err != nil { + panic("Error in test code") + } else { + responses = append(responses, response) + } } - return &fireboltRows{response, 0} + + return &fireboltRows{responses, 0, 0} } // TestRowsColumns checks, that correct column names are returned func TestRowsColumns(t *testing.T) { - rows := mockRows() + rows := mockRows(false) columnNames := []string{"int_col", "bigint_col", "float_col", "double_col", "text_col", "date_col", "timestamp_col", "bool_col", "array_col", "nested_array_col"} if !reflect.DeepEqual(rows.Columns(), columnNames) { @@ -63,7 +84,7 @@ func TestRowsColumns(t *testing.T) { // TestRowsClose checks Close method, and inability to use rows afterward func TestRowsClose(t *testing.T) { - rows := mockRows() + rows := mockRows(false) if rows.Close() != nil { t.Errorf("Closing rows was not successful") } @@ -76,7 +97,7 @@ func TestRowsClose(t *testing.T) { // TestRowsNext check Next method func TestRowsNext(t *testing.T) { - rows := mockRows() + rows := mockRows(false) var dest = make([]driver.Value, 10) err := rows.Next(dest) loc, _ := time.LoadLocation("UTC") @@ -103,4 +124,31 @@ func TestRowsNext(t *testing.T) { assert(dest[4] == "text", t, "results not equal for string") assert(io.EOF == rows.Next(dest), t, "Next should return io.EOF if no data available anymore") + + assert(rows.HasNextResultSet() == false, t, "Has Next result set didn't return false") + assert(rows.NextResultSet() == io.EOF, t, "Next result set didn't return false") +} + +// TestRowsNextSet check rows with multiple statements +func TestRowsNextSet(t *testing.T) { + rows := mockRows(true) + + // check next result set functions + assert(rows.HasNextResultSet() == true, t, "HasNextResultSet returned false, but shouldn't") + assert(rows.NextResultSet() == nil, t, "NextResultSet returned an error, but shouldn't") + assert(rows.HasNextResultSet() == false, t, "HasNextResultSet returned true, but shouldn't") + + // check columns of the next result set + assert(reflect.DeepEqual(rows.Columns(), []string{"int_col"}), t, "Columns of the next result set are incorrect") + + // check values of the next result set + var dest = make([]driver.Value, 1) + + assert(rows.Next(dest) == nil, t, "Next shouldn't return an error") + assert(dest[0] == int32(3), t, "results are not equal") + + assert(rows.Next(dest) == nil, t, "Next shouldn't return an error") + assert(dest[0] == nil, t, "results are not equal") + + assert(io.EOF == rows.Next(dest), t, "Next should return io.EOF if no data available anymore") } diff --git a/statement.go b/statement.go index 590ad99..bcd3ba2 100644 --- a/statement.go +++ b/statement.go @@ -3,7 +3,6 @@ package fireboltgosdk import ( "context" "database/sql/driver" - "fmt" ) type Column struct { @@ -42,18 +41,12 @@ func (stmt *fireboltStmt) NumInput() int { // Exec calls ExecContext with dummy context func (stmt *fireboltStmt) Exec(args []driver.Value) (driver.Result, error) { - if len(args) != 0 { - return nil, fmt.Errorf("Prepared statements are not implemented") - } - return stmt.ExecContext(context.TODO(), make([]driver.NamedValue, 0)) + return stmt.ExecContext(context.TODO(), valueToNamedValue(args)) } // Query calls QueryContext with dummy context func (stmt *fireboltStmt) Query(args []driver.Value) (driver.Rows, error) { - if len(args) != 0 { - return nil, fmt.Errorf("Prepared statements are not implemented") - } - return stmt.QueryContext(context.TODO(), make([]driver.NamedValue, 0)) + return stmt.QueryContext(context.TODO(), valueToNamedValue(args)) } // QueryContext sends the query to the engine and returns fireboltRows diff --git a/utils.go b/utils.go index ad72504..a4a343d 100644 --- a/utils.go +++ b/utils.go @@ -78,6 +78,24 @@ func prepareStatement(query string, params []driver.NamedValue) (string, error) return query, nil } +// SplitStatements split multiple statements into a list of statements +func SplitStatements(sql string) ([]string, error) { + var queries []string + + for sql != "" { + var err error + var query string + + query, sql, err = sqlparser.SplitStatement(sql) + if err != nil { + return nil, ConstructNestedError("error during splitting query", err) + } + queries = append(queries, query) + } + + return queries, nil +} + func formatValue(value driver.Value) (string, error) { switch v := value.(type) { case string: @@ -134,3 +152,11 @@ func ConstructUserAgentString() string { return strings.TrimSpace(fmt.Sprintf("%s GoSDK/%s (Go %s; %s) %s", goClients, sdkVersion, runtime.Version(), osNameVersion, goDrivers)) } + +func valueToNamedValue(args []driver.Value) []driver.NamedValue { + namedValues := make([]driver.NamedValue, 0, len(args)) + for i, arg := range args { + namedValues = append(namedValues, driver.NamedValue{Ordinal: i, Value: arg}) + } + return namedValues +} diff --git a/utils_test.go b/utils_test.go index 02f0f61..8653cbc 100644 --- a/utils_test.go +++ b/utils_test.go @@ -3,6 +3,7 @@ package fireboltgosdk import ( "database/sql/driver" "os" + "reflect" "strings" "testing" "time" @@ -130,3 +131,34 @@ func TestConstructUserAgentString(t *testing.T) { os.Unsetenv("FIREBOLT_GO_DRIVERS") os.Unsetenv("FIREBOLT_GO_CLIENTS") } + +func runSplitStatement(t *testing.T, value string, expected []string) { + stmts, err := SplitStatements(value) + if err != nil { + t.Errorf("SplitStatements return an error for: %v", value) + } + + if !reflect.DeepEqual(stmts, expected) { + t.Errorf("SplitStatements returned and expected are not equal: %v != %v", stmts, expected) + } +} + +func TestSplitStatements(t *testing.T) { + runSplitStatement(t, "SELECT 1; SELECT 2;", []string{"SELECT 1", " SELECT 2"}) + runSplitStatement(t, "SELECT 1;", []string{"SELECT 1"}) + runSplitStatement(t, "SELECT 1", []string{"SELECT 1"}) + runSplitStatement(t, "SELECT 1; ; ; ; ", []string{"SELECT 1", " ", " ", " ", " "}) + + runSplitStatement(t, "SET advanced_mode=1; SELECT 2 /*some ; comment*/", []string{"SET advanced_mode=1", " SELECT 2 /*some ; comment*/"}) + runSplitStatement(t, "SET advanced_mode=';'; SELECT 2 /*some ; comment*/", []string{"SET advanced_mode=';'", " SELECT 2 /*some ; comment*/"}) + runSplitStatement(t, "SELECT 1; SELECT 2; SELECT 3; SELECT 4; SELECT 5; SELECT 6", []string{"SELECT 1", " SELECT 2", " SELECT 3", " SELECT 4", " SELECT 5", " SELECT 6"}) +} + +func TestValueToNamedValue(t *testing.T) { + assert(len(valueToNamedValue([]driver.Value{})) == 0, t, "valueToNamedValue of empty array is wrong") + + namedValues := valueToNamedValue([]driver.Value{2, "string"}) + assert(len(namedValues) == 2, t, "len of namedValues is wrong") + assert(namedValues[0].Value == 2, t, "namedValues value is wrong") + assert(namedValues[1].Value == "string", t, "namedValues value is wrong") +}