diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 35b8199..6e4b5fe 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -19,7 +19,7 @@ jobs: with: repository: 'taosdata/TDengine' path: 'TDengine' - ref: '3.0' + ref: 'main' - name: install TDengine run: | diff --git a/.gitignore b/.gitignore index a09c56d..3a56068 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ /.idea +vendor \ No newline at end of file diff --git a/README-CN.md b/README-CN.md index 83f8d98..6757da6 100644 --- a/README-CN.md +++ b/README-CN.md @@ -4,7 +4,7 @@ [English](README.md) | 简体中文 -[TDengine]提供了 GO 驱动程序 [`taosSql`][driver-go],实现了 GO 语言的内置数据库操作接口 `database/sql/driver`。 +[TDengine] 提供了 GO 驱动程序 [`taosSql`][driver-go],实现了 GO 语言的内置数据库操作接口 `database/sql/driver`。 ## 提示 @@ -13,6 +13,10 @@ v2 与 v3 版本不兼容,与 TDengine 版本对应如下: | **driver-go 版本** | **TDengine 版本** | |------------------|-----------------| | v3.0.0 | 3.0.0.0+ | +| v3.0.1 | 3.0.0.0+ | +| v3.0.3 | 3.0.1.5+ | +| v3.0.4 | 3.0.2.2+ | +| v3.1.0 | 3.0.2.2+ | ## 安装 @@ -126,31 +130,31 @@ func main() { 创建消费: ```go -func NewConsumer(conf *Config) (*Consumer, error) +func NewConsumer(conf *tmq.ConfigMap) (*Consumer, error) ``` -订阅: +订阅单个主题: ```go -func (c *Consumer) Subscribe(topics []string) error +func (c *Consumer) Subscribe(topic string, rebalanceCb RebalanceCb) error ``` -轮询消息: +订阅: ```go -func (c *Consumer) Poll(timeout time.Duration) (*Result, error) +func (c *Consumer) SubscribeTopics(topics []string, rebalanceCb RebalanceCb) error ``` -提交消息: +轮询消息: ```go -func (c *Consumer) Commit(ctx context.Context, message unsafe.Pointer) error +func (c *Consumer) Poll(timeoutMs int) tmq.Event ``` -释放消息: +提交消息: ```go -func (c *Consumer) FreeMessage(message unsafe.Pointer) +func (c *Consumer) Commit() ([]tmq.TopicPartition, error) ``` 取消订阅: @@ -318,7 +322,9 @@ import ( DSN 格式为: -```数据库用户名:数据库密码@连接方式(域名或ip:端口)/[数据库][?参数]``` +```text +数据库用户名:数据库密码@连接方式(域名或 ip:端口)/[数据库][?参数] +``` 样例: @@ -383,6 +389,242 @@ func main() { } ``` +## websocket 实现 `database/sql` 标准接口 + +通过 websocket 方式实现 `database/sql` 接口,使用方法简单示例如下: + +```go +package main + +import ( + "database/sql" + "fmt" + "time" + + _ "github.com/taosdata/driver-go/v3/taosWS" +) + +func main() { + var taosDSN = "root:taosdata@ws(localhost:6041)/" + taos, err := sql.Open("taosWS", taosDSN) + if err != nil { + fmt.Println("failed to connect TDengine, err:", err) + return + } + defer taos.Close() + taos.Exec("create database if not exists test") + taos.Exec("create table if not exists test.tb1 (ts timestamp, a int)") + _, err = taos.Exec("insert into test.tb1 values(now, 0)(now+1s,1)(now+2s,2)(now+3s,3)") + if err != nil { + fmt.Println("failed to insert, err:", err) + return + } + rows, err := taos.Query("select * from test.tb1") + if err != nil { + fmt.Println("failed to select from table, err:", err) + return + } + + defer rows.Close() + for rows.Next() { + var r struct { + ts time.Time + a int + } + err := rows.Scan(&r.ts, &r.a) + if err != nil { + fmt.Println("scan error:\n", err) + return + } + fmt.Println(r.ts, r.a) + } +} +``` + +### 使用 + +引入 + +```go +import ( + "database/sql" + _ "github.com/taosdata/driver-go/v3/taosWS" +) +``` + +`sql.Open` 的 driverName 为 `taosWS` + +DSN 格式为: + +```text +数据库用户名:数据库密码@连接方式(域名或 ip:端口)/[数据库][?参数] +``` + +样例: + +```root:taosdata@ws(localhost:6041)/test?writeTimeout=10s&readTimeout=10m``` + +参数: + +- `writeTimeout` 通过 websocket 发送数据的超时时间。 +- `readTimeout` 通过 websocket 接收响应数据的超时时间。 + +## 通过 websocket 使用 tmq + +通过 websocket 方式使用 tmq。服务端需要启动 taoAdapter。 + +### 配置相关 API + +- `func NewConfig(url string, chanLength uint) *Config` + + 创建配置项,传入 websocket 地址和发送管道长度。 + +- `func (c *Config) SetConnectUser(user string) error` + + 设置用户名。 + +- `func (c *Config) SetConnectPass(pass string) error` + + 设置密码。 + +- `func (c *Config) SetClientID(clientID string) error` + + 设置客户端标识。 + +- `func (c *Config) SetGroupID(groupID string) error` + + 设置订阅组 ID。 + +- `func (c *Config) SetWriteWait(writeWait time.Duration) error` + + 设置发送消息等待时间。 + +- `func (c *Config) SetMessageTimeout(timeout time.Duration) error` + + 设置消息超时时间。 + +- `func (c *Config) SetErrorHandler(f func(consumer *Consumer, err error))` + + 设置错误处理方法。 + +- `func (c *Config) SetCloseHandler(f func())` + + 设置关闭处理方法。 + +### 订阅相关 API + +- `func NewConsumer(conf *tmq.ConfigMap) (*Consumer, error)` + + 创建消费者。 + +- `func (c *Consumer) Subscribe(topic string, rebalanceCb RebalanceCb) error` + + 订阅单个主题。 + +- `func (c *Consumer) SubscribeTopics(topics []string, rebalanceCb RebalanceCb) error` + + 订阅主题。 + +- `func (c *Consumer) Poll(timeoutMs int) tmq.Event` + + 轮询消息。 + +- `func (c *Consumer) Commit() ([]tmq.TopicPartition, error)` + + 提交消息。 + +- `func (c *Consumer) Close() error` + + 关闭连接。 + +示例代码:[`examples/tmqoverws/main.go`](examples/tmqoverws/main.go)。 + +## 通过 WebSocket 进行参数绑定 + +通过 websocket 方式使用 stmt。服务端需要启动 taoAdapter。 + +### 配置相关 API + +- `func NewConfig(url string, chanLength uint) *Config` + + 创建配置项,传入 websocket 地址和发送管道长度。 + +- `func (c *Config) SetCloseHandler(f func())` + + 设置关闭处理方法。 + +- `func (c *Config) SetConnectDB(db string) error` + + 设置连接 DB。 + +- `func (c *Config) SetConnectPass(pass string) error` + + 设置连接密码。 + +- `func (c *Config) SetConnectUser(user string) error` + + 设置连接用户名。 + +- `func (c *Config) SetErrorHandler(f func(connector *Connector, err error))` + + 设置错误处理函数。 + +- `func (c *Config) SetMessageTimeout(timeout time.Duration) error` + + 设置消息超时时间。 + +- `func (c *Config) SetWriteWait(writeWait time.Duration) error` + + 设置发送消息等待时间。 + +### 参数绑定相关 API + +* `func NewConnector(config *Config) (*Connector, error)` + + 创建连接。 + +* `func (c *Connector) Init() (*Stmt, error)` + + 初始化参数。 + +* `func (c *Connector) Close() error` + + 关闭连接。 + +* `func (s *Stmt) Prepare(sql string) error` + + 参数绑定预处理 SQL 语句。 + +* `func (s *Stmt) SetTableName(name string) error` + + 参数绑定设置表名。 + +* `func (s *Stmt) SetTags(tags *param.Param, bindType *param.ColumnType) error` + + 参数绑定设置标签。 + +* `func (s *Stmt) BindParam(params []*param.Param, bindType *param.ColumnType) error` + + 参数绑定多行数据。 + +* `func (s *Stmt) AddBatch() error` + + 添加到参数绑定批处理。 + +* `func (s *Stmt) Exec() error` + + 执行参数绑定。 + +* `func (s *Stmt) GetAffectedRows() int` + + 获取参数绑定插入受影响行数。 + +* `func (s *Stmt) Close() error` + + 结束参数绑定。 + +完整参数绑定示例参见 [GitHub 示例文件](examples/stmtoverws/main.go) + ## 目录结构 ```text @@ -391,10 +633,11 @@ driver-go ├── common //通用方法以及常量 ├── errors //错误类型 ├── examples //样例 -├── taosRestful // 数据库操作标准接口(restful) +├── taosRestful // 数据库操作标准接口 (restful) ├── taosSql // 数据库操作标准接口 ├── types // 内置类型 -└── wrapper // cgo 包装器 +├── wrapper // cgo 包装器 +└── ws // websocket ``` ## 导航 diff --git a/README.md b/README.md index 61f8add..f674233 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,10 @@ v2 is not compatible with v3 version and corresponds to the TDengine version as | **driver-go version** | **TDengine version** | |-----------------------|----------------------| | v3.0.0 | 3.0.0.0+ | +| v3.0.1 | 3.0.0.0+ | +| v3.0.3 | 3.0.1.5+ | +| v3.0.4 | 3.0.2.2+ | +| v3.1.0 | 3.0.2.2+ | ## Install @@ -123,31 +127,31 @@ APIs that are worthy to have a check: Create consumer: ````go -func NewConsumer(conf *Config) (*Consumer, error) +func NewConsumer(conf *tmq.ConfigMap) (*Consumer, error) ```` -Subscribe: +Subscribe single topic: ````go -func (c *Consumer) Subscribe(topics []string) error +func (c *Consumer) Subscribe(topic string, rebalanceCb RebalanceCb) error ```` -Poll message: +Subscribe topics: ````go -func (c *Consumer) Poll(timeout time.Duration) (*Result, error) +func (c *Consumer) SubscribeTopics(topics []string, rebalanceCb RebalanceCb) error ```` -Commit message: +Poll message: ````go -func (c *Consumer) Commit(ctx context.Context, message unsafe.Pointer) error +func (c *Consumer) Poll(timeoutMs int) tmq.Event ```` -Free message: +Commit message: ````go -func (c *Consumer) FreeMessage(message unsafe.Pointer) +func (c *Consumer) Commit() ([]tmq.TopicPartition, error) ```` Unsubscribe: @@ -311,15 +315,6 @@ import ( ) ``` -Introduce - -```go -import ( - "database/sql" - _ "github.com/taosdata/driver-go/v3/taosRestful" -) -``` - The driverName of `sql.Open` is `taosRestful` The DSN format is: @@ -393,6 +388,244 @@ func main() { } ``` +## websocket implementation of the `database/sql` standard interface + +A simple use case: + +```go +package main + +import ( + "database/sql" + "fmt" + "time" + + _ "github.com/taosdata/driver-go/v3/taosWS" +) + +func main() { + var taosDSN = "root:taosdata@ws(localhost:6041)/" + taos, err := sql.Open("taosWS", taosDSN) + if err != nil { + fmt.Println("failed to connect TDengine, err:", err) + return + } + defer taos.Close() + taos.Exec("create database if not exists test") + taos.Exec("create table if not exists test.tb1 (ts timestamp, a int)") + _, err = taos.Exec("insert into test.tb1 values(now, 0)(now+1s,1)(now+2s,2)(now+3s,3)") + if err != nil { + fmt.Println("failed to insert, err:", err) + return + } + rows, err := taos.Query("select * from test.tb1") + if err != nil { + fmt.Println("failed to select from table, err:", err) + return + } + + defer rows.Close() + for rows.Next() { + var r struct { + ts time.Time + a int + } + err := rows.Scan(&r.ts, &r.a) + if err != nil { + fmt.Println("scan error:\n", err) + return + } + fmt.Println(r.ts, r.a) + } +} +``` + +### Usage of websocket + +import + +```go +import ( + "database/sql" + _ "github.com/taosdata/driver-go/v3/taosWS" +) +``` + +The driverName of `sql.Open` is `taosWS` + +The DSN format is: + +```text +database username:database password@connection-method(domain or ip:port)/[database][? parameter] +``` + +Example: + +```text +root:taosdata@ws(localhost:6041)/test?writeTimeout=10s&readTimeout=10m +``` + +Parameters: + +- `writeTimeout` The timeout to send data via websocket. +- `readTimeout` The timeout to receive response data via websocket. + +## Using tmq over websocket + +Use tmq over websocket. The server needs to start taoAdapter. + +### Configure related API + +- `func NewConfig(url string, chanLength uint) *Config` + + Create a configuration, pass in the websocket address and the length of the sending channel. + +- `func (c *Config) SetConnectUser(user string) error` + + Set username. + +- `func (c *Config) SetConnectPass(pass string) error` + + Set password. + +- `func (c *Config) SetClientID(clientID string) error` + + Set the client ID. + +- `func (c *Config) SetGroupID(groupID string) error` + + Set the subscription group ID. + +- `func (c *Config) SetWriteWait(writeWait time.Duration) error` + + Set the waiting time for sending messages. + +- `func (c *Config) SetMessageTimeout(timeout time.Duration) error` + + Set the message timeout. + +- `func (c *Config) SetErrorHandler(f func(consumer *Consumer, err error))` + + Set the error handler. + +- `func (c *Config) SetCloseHandler(f func())` + + Set the close handler. + +### Subscription related API + +- `func NewConsumer(conf *tmq.ConfigMap) (*Consumer, error)` + + Create a consumer. + +- `func (c *Consumer) Subscribe(topic string, rebalanceCb RebalanceCb) error` + + Subscribe a topic. + +- `func (c *Consumer) SubscribeTopics(topics []string, rebalanceCb RebalanceCb) error` + + Subscribe to topics. + +- `func (c *Consumer) Poll(timeoutMs int) tmq.Event` + + Poll messages. + +- `func (c *Consumer) Commit() ([]tmq.TopicPartition, error)` + + Commit message. + +- `func (c *Consumer) Close() error` + + Close the connection. + +Example code: [`examples/tmqoverws/main.go`](examples/tmqoverws/main.go). + +## Parameter binding via WebSocket + +Use stmt via websocket. The server needs to start taoAdapter. + +### Configure related API + +- `func NewConfig(url string, chanLength uint) *Config` + + Create a configuration item, pass in the websocket address and the length of the sending pipe. + +- `func (c *Config) SetCloseHandler(f func())` + + Set close handler. + +- `func (c *Config) SetConnectDB(db string) error` + + Set connect DB. + +- `func (c *Config) SetConnectPass(pass string) error` + + Set password. + +- `func (c *Config) SetConnectUser(user string) error` + + Set username. + +- `func (c *Config) SetErrorHandler(f func(connector *Connector, err error))` + + Set error handler. + +- `func (c *Config) SetMessageTimeout(timeout time.Duration) error` + + Set the message timeout. + +- `func (c *Config) SetWriteWait(writeWait time.Duration) error` + + Set the waiting time for sending messages. + +### Parameter binding related API + +* `func NewConnector(config *Config) (*Connector, error)` + + Create a connection. + +* `func (c *Connector) Init() (*Stmt, error)` + + Initialize the parameters. + +* `func (c *Connector) Close() error` + + Close the connection. + +* `func (s *Stmt) Prepare(sql string) error` + + Parameter binding preprocessing SQL statement. + +* `func (s *Stmt) SetTableName(name string) error` + + Bind the table name parameter. + +* `func (s *Stmt) SetTags(tags *param.Param, bindType *param.ColumnType) error` + + Bind tags. + +* `func (s *Stmt) BindParam(params []*param.Param, bindType *param.ColumnType) error` + + Parameter bind multiple rows of data. + +* `func (s *Stmt) AddBatch() error` + + Add to a parameter-bound batch. + +* `func (s *Stmt) Exec() error` + + Execute a parameter binding. + +* `func (s *Stmt) GetAffectedRows() int` + + Gets the number of affected rows inserted by the parameter binding. + +* `func (s *Stmt) Close() error` + + Closes the parameter binding. + +For a complete example of parameter binding, see [GitHub example file](examples/stmtoverws/main.go) + ## Directory structure ```text @@ -404,7 +637,8 @@ driver-go ├── taosRestful // database operation standard interface (restful) ├── taosSql // database operation standard interface ├── types // inner type -└── wrapper // cgo wrapper +├── wrapper // cgo wrapper +└── ws // websocket ``` ## Link diff --git a/af/conn.go b/af/conn.go index f87e0a5..869c7c3 100644 --- a/af/conn.go +++ b/af/conn.go @@ -63,6 +63,22 @@ func (conn *Connector) StmtExecute(sql string, params *param.Param) (res driver. } defer stmt.Close() + return conn.stmtExecute(stmt, sql, params) +} + +// StmtExecuteWithReqID Execute sql through stmt with reqID +func (conn *Connector) StmtExecuteWithReqID(sql string, params *param.Param, reqID int64) (res driver.Result, err error) { + stmt := NewStmtWithReqID(conn.taos, reqID) + if stmt == nil { + err = &errors.TaosError{Code: 0xffff, ErrStr: "failed to init stmt"} + return + } + + defer stmt.Close() + return conn.stmtExecute(stmt, sql, params) +} + +func (conn *Connector) stmtExecute(stmt *Stmt, sql string, params *param.Param) (res driver.Result, err error) { err = stmt.Prepare(sql) if err != nil { return nil, err @@ -89,7 +105,25 @@ func (conn *Connector) Exec(query string, args ...driver.Value) (driver.Result, return nil, driver.ErrBadConn } if len(args) != 0 { - prepared, err := common.InterpolateParams(query, args) + prepared, err := common.InterpolateParams(query, common.ValueArgsToNamedValueArgs(args)) + if err != nil { + return nil, err + } + query = prepared + } + asyncHandler := async.GetHandler() + defer async.PutHandler(asyncHandler) + result := conn.taosQuery(query, asyncHandler, 0) + return conn.processExecResult(result) +} + +// ExecWithReqID Execute sql with reqID +func (conn *Connector) ExecWithReqID(query string, reqID int64, args ...driver.Value) (driver.Result, error) { + if conn.taos == nil { + return nil, driver.ErrBadConn + } + if len(args) != 0 { + prepared, err := common.InterpolateParams(query, common.ValueArgsToNamedValueArgs(args)) if err != nil { return nil, err } @@ -97,7 +131,11 @@ func (conn *Connector) Exec(query string, args ...driver.Value) (driver.Result, } asyncHandler := async.GetHandler() defer async.PutHandler(asyncHandler) - result := conn.taosQuery(query, asyncHandler) + result := conn.taosQuery(query, asyncHandler, reqID) + return conn.processExecResult(result) +} + +func (conn *Connector) processExecResult(result *handler.AsyncResult) (driver.Result, error) { defer func() { if result != nil && result.Res != nil { locker.Lock() @@ -106,8 +144,7 @@ func (conn *Connector) Exec(query string, args ...driver.Value) (driver.Result, } }() res := result.Res - code := wrapper.TaosError(res) - if code != int(errors.SUCCESS) { + if code := wrapper.TaosError(res); code != int(errors.SUCCESS) { errStr := wrapper.TaosErrorStr(res) return nil, errors.NewError(code, errStr) } @@ -121,19 +158,38 @@ func (conn *Connector) Query(query string, args ...driver.Value) (driver.Rows, e return nil, driver.ErrBadConn } if len(args) != 0 { - prepared, err := common.InterpolateParams(query, args) + prepared, err := common.InterpolateParams(query, common.ValueArgsToNamedValueArgs(args)) + if err != nil { + return nil, err + } + query = prepared + } + h := async.GetHandler() + result := conn.taosQuery(query, h, 0) + return conn.processQueryResult(result, h) +} + +// QueryWithReqID Execute query sql with reqID +func (conn *Connector) QueryWithReqID(query string, reqID int64, args ...driver.Value) (driver.Rows, error) { + if conn.taos == nil { + return nil, driver.ErrBadConn + } + if len(args) != 0 { + prepared, err := common.InterpolateParams(query, common.ValueArgsToNamedValueArgs(args)) if err != nil { return nil, err } query = prepared } + h := async.GetHandler() + result := conn.taosQuery(query, h, reqID) + return conn.processQueryResult(result, h) +} - handler := async.GetHandler() - result := conn.taosQuery(query, handler) +func (conn *Connector) processQueryResult(result *handler.AsyncResult, h *handler.Handler) (driver.Rows, error) { res := result.Res - code := wrapper.TaosError(res) - if code != int(errors.SUCCESS) { - async.PutHandler(handler) + if code := wrapper.TaosError(res); code != int(errors.SUCCESS) { + async.PutHandler(h) errStr := wrapper.TaosErrorStr(res) locker.Lock() wrapper.TaosFreeResult(result.Res) @@ -143,22 +199,26 @@ func (conn *Connector) Query(query string, args ...driver.Value) (driver.Rows, e numFields := wrapper.TaosNumFields(res) rowsHeader, err := wrapper.ReadColumn(res, numFields) if err != nil { + async.PutHandler(h) return nil, err } precision := wrapper.TaosResultPrecision(res) rs := &rows{ - handler: handler, + handler: h, rowsHeader: rowsHeader, result: res, precision: precision, } return rs, nil - } -func (conn *Connector) taosQuery(sqlStr string, handler *handler.Handler) *handler.AsyncResult { +func (conn *Connector) taosQuery(sqlStr string, handler *handler.Handler, reqID int64) *handler.AsyncResult { locker.Lock() - wrapper.TaosQueryA(conn.taos, sqlStr, handler.Handler) + if reqID == 0 { + wrapper.TaosQueryA(conn.taos, sqlStr, handler.Handler) + } else { + wrapper.TaosQueryAWithReqID(conn.taos, sqlStr, handler.Handler, reqID) + } locker.Unlock() r := <-handler.Caller.QueryResult return r @@ -169,6 +229,11 @@ func (conn *Connector) InsertStmt() *insertstmt.InsertStmt { return insertstmt.NewInsertStmt(conn.taos) } +// InsertStmtWithReqID Prepare batch insert stmt with reqID +func (conn *Connector) InsertStmtWithReqID(reqID int64) *insertstmt.InsertStmt { + return insertstmt.NewInsertStmtWithReqID(conn.taos, reqID) +} + // SelectDB Execute `use db` func (conn *Connector) SelectDB(db string) error { locker.Lock() @@ -182,6 +247,7 @@ func (conn *Connector) SelectDB(db string) error { } // InfluxDBInsertLines Insert data using influxdb line format +// Deprecated func (conn *Connector) InfluxDBInsertLines(lines []string, precision string) error { locker.Lock() result := wrapper.TaosSchemalessInsert(conn.taos, lines, wrapper.InfluxDBLineProtocol, precision) @@ -201,6 +267,7 @@ func (conn *Connector) InfluxDBInsertLines(lines []string, precision string) err } // OpenTSDBInsertTelnetLines Insert data using opentsdb telnet format +// Deprecated func (conn *Connector) OpenTSDBInsertTelnetLines(lines []string) error { locker.Lock() result := wrapper.TaosSchemalessInsert(conn.taos, lines, wrapper.OpenTSDBTelnetLineProtocol, "") @@ -218,6 +285,7 @@ func (conn *Connector) OpenTSDBInsertTelnetLines(lines []string) error { } // OpenTSDBInsertJsonPayload Insert data using opentsdb json format +// Deprecated func (conn *Connector) OpenTSDBInsertJsonPayload(payload string) error { result := wrapper.TaosSchemalessInsert(conn.taos, []string{payload}, wrapper.OpenTSDBJsonFormatProtocol, "") code := wrapper.TaosError(result) @@ -233,3 +301,12 @@ func (conn *Connector) OpenTSDBInsertJsonPayload(payload string) error { locker.Unlock() return nil } + +func (conn *Connector) GetTableVGroupID(db, table string) (vgID int, err error) { + var code int + vgID, code = wrapper.TaosGetTableVgID(conn.taos, db, table) + if code != 0 { + err = errors.NewError(code, wrapper.TaosErrorStr(nil)) + } + return +} diff --git a/af/conn_test.go b/af/conn_test.go index e399f4f..fb5eeaa 100644 --- a/af/conn_test.go +++ b/af/conn_test.go @@ -796,3 +796,113 @@ func TestOpenTSDBInsertJsonPayloadWrong(t *testing.T) { return } } + +func TestConnector_StmtExecuteWithReqID(t *testing.T) { + db := testDatabase(t) + defer db.Close() + _, err := db.ExecWithReqID("create stable if not exists meters (ts timestamp, current float, voltage int, phase float) tags (location binary(64), groupId int)", + common.GetReqID()) + if err != nil { + t.Fatal(err) + } + params := param2.NewParam(4) + params.AddTimestamp(time.Now(), 0). + AddFloat(10.2).AddInt(219).AddFloat(0.32) + _, err = db.StmtExecuteWithReqID("INSERT INTO d21001 USING meters TAGS ('California.SanFrancisco', 2) "+ + "VALUES (?, ?, ?, ?)", + params, + common.GetReqID()) + if err != nil { + t.Fatal(err) + } + _, err = db.ExecWithReqID("drop stable if exists meters", + common.GetReqID()) + if err != nil { + t.Fatal(err) + } +} + +func TestConnector_InsertStmtWithReqID(t *testing.T) { + db := testDatabase(t) + defer db.Close() + _, err := db.ExecWithReqID("create stable if not exists meters (ts timestamp, current float, voltage int, phase float) tags (location binary(64), groupId int)", + common.GetReqID()) + if err != nil { + t.Fatal(err) + } + defer func() { + _, _ = db.ExecWithReqID("drop stable if exists meters", common.GetReqID()) + }() + params := []*param2.Param{ + param2.NewParam(1).AddTimestamp(time.Now(), common.PrecisionMilliSecond), + param2.NewParam(1).AddFloat(10.2), + param2.NewParam(1).AddInt(219), + param2.NewParam(1).AddFloat(0.32), + } + bindType := param2.NewColumnType(4).AddTimestamp().AddFloat().AddInt().AddFloat() + + stmt := db.InsertStmtWithReqID(common.GetReqID()) + defer stmt.Close() + stmt.Prepare("INSERT INTO d21001 USING meters TAGS ('California.SanFrancisco', 2) VALUES (?, ?, ?, ?)") + stmt.BindParam(params, bindType) + stmt.AddBatch() + err = stmt.Execute() + if err != nil { + t.Fatal(err) + } + if stmt.GetAffectedRows() != 1 { + t.Fatal("result miss") + } + +} + +func TestConnector_ExecWithReqID(t *testing.T) { + db := testDatabase(t) + defer db.Close() + _, err := db.ExecWithReqID("create stable if not exists meters (ts timestamp, current float, voltage int, phase float) tags (location binary(64), groupId int)", + common.GetReqID()) + if err != nil { + t.Fatal(err) + } + defer func() { + _, _ = db.ExecWithReqID("drop stable if exists meters", common.GetReqID()) + }() + _, err = db.ExecWithReqID("INSERT INTO d21001 USING meters TAGS ('California.SanFrancisco', 2) VALUES ('2021-07-13 14:06:32.272', 10.2, 219, 0.32)", + common.GetReqID()) + if err != nil { + t.Fatal(err) + } +} + +func TestConnector_QueryWithReqID(t *testing.T) { + db := testDatabase(t) + defer db.Close() + _, err := db.ExecWithReqID("create stable if not exists meters (ts timestamp, current float, voltage int, phase float) tags (location binary(64), groupId int)", + common.GetReqID()) + if err != nil { + t.Fatal(err) + } + defer func() { + _, _ = db.ExecWithReqID("drop stable if exists meters", common.GetReqID()) + }() + + _, err = db.ExecWithReqID("INSERT INTO d21001 USING meters TAGS ('California.SanFrancisco', 2) VALUES ('2021-07-13 14:06:32.272', 10.2, 219, 0.32)", + common.GetReqID()) + if err != nil { + t.Fatal(err) + } + res, err := db.QueryWithReqID("select count(*) from meters", common.GetReqID()) + if err != nil { + t.Fatal(err) + } + defer res.Close() + v := make([]driver.Value, 1) + err = res.Next(v) + if err != nil { + t.Fatal(err) + return + } + if v[0].(int64) != 1 { + t.Fatal("result is error") + } +} diff --git a/af/insertstmt/stmt.go b/af/insertstmt/stmt.go index 98d72ec..aae9795 100644 --- a/af/insertstmt/stmt.go +++ b/af/insertstmt/stmt.go @@ -22,6 +22,13 @@ func NewInsertStmt(taosConn unsafe.Pointer) *InsertStmt { return &InsertStmt{stmt: stmt} } +func NewInsertStmtWithReqID(taosConn unsafe.Pointer, reqID int64) *InsertStmt { + locker.Lock() + stmt := wrapper.TaosStmtInitWithReqID(taosConn, reqID) + locker.Unlock() + return &InsertStmt{stmt: stmt} +} + func (stmt *InsertStmt) Prepare(sql string) error { locker.Lock() code := wrapper.TaosStmtPrepare(stmt.stmt, sql) diff --git a/af/rows.go b/af/rows.go index cf9f294..a1702ae 100644 --- a/af/rows.go +++ b/af/rows.go @@ -8,6 +8,7 @@ import ( "github.com/taosdata/driver-go/v3/af/async" "github.com/taosdata/driver-go/v3/af/locker" + "github.com/taosdata/driver-go/v3/common/parser" "github.com/taosdata/driver-go/v3/errors" "github.com/taosdata/driver-go/v3/wrapper" "github.com/taosdata/driver-go/v3/wrapper/handler" @@ -56,7 +57,9 @@ func (rs *rows) Next(dest []driver.Value) error { } if rs.block == nil { - rs.taosFetchBlock() + if err := rs.taosFetchBlock(); err != nil { + return err + } } if rs.blockSize == 0 { rs.block = nil @@ -65,14 +68,16 @@ func (rs *rows) Next(dest []driver.Value) error { } if rs.blockOffset >= rs.blockSize { - rs.taosFetchBlock() + if err := rs.taosFetchBlock(); err != nil { + return err + } } if rs.blockSize == 0 { rs.block = nil rs.freeResult() return io.EOF } - wrapper.ReadRow(dest, rs.block, rs.blockSize, rs.blockOffset, rs.rowsHeader.ColTypes, rs.precision) + parser.ReadRow(dest, rs.block, rs.blockSize, rs.blockOffset, rs.rowsHeader.ColTypes, rs.precision) rs.blockOffset++ return nil } @@ -111,7 +116,9 @@ func (rs *rows) freeResult() { locker.Unlock() rs.result = nil } + if rs.handler != nil { async.PutHandler(rs.handler) + rs.handler = nil } } diff --git a/af/stmt.go b/af/stmt.go index 31b0457..18d52c2 100644 --- a/af/stmt.go +++ b/af/stmt.go @@ -24,6 +24,13 @@ func NewStmt(taosConn unsafe.Pointer) *Stmt { return &Stmt{stmt: stmt} } +func NewStmtWithReqID(taosConn unsafe.Pointer, reqID int64) *Stmt { + locker.Lock() + stmt := wrapper.TaosStmtInitWithReqID(taosConn, reqID) + locker.Unlock() + return &Stmt{stmt: stmt} +} + func (s *Stmt) Prepare(sql string) error { locker.Lock() code := wrapper.TaosStmtPrepare(s.stmt, sql) diff --git a/af/tmq/config.go b/af/tmq/config.go index c277895..f037c46 100644 --- a/af/tmq/config.go +++ b/af/tmq/config.go @@ -1,64 +1,21 @@ package tmq import ( - "strconv" "unsafe" "github.com/taosdata/driver-go/v3/errors" "github.com/taosdata/driver-go/v3/wrapper" ) -type Config struct { - cConfig unsafe.Pointer - autoCommit bool - cb CommitHandleFunc - needGetTableName bool +type config struct { + cConfig unsafe.Pointer } -type CommitHandleFunc func(*wrapper.TMQCommitCallbackResult) - -// NewConfig New TMQ config -func NewConfig() *Config { - return &Config{cConfig: wrapper.TMQConfNew()} -} - -// SetGroupID TMQ set `group.id` -func (c *Config) SetGroupID(groupID string) error { - return c.SetConfig("group.id", groupID) -} - -// SetAutoOffsetReset TMQ set `auto.offset.reset` -func (c *Config) SetAutoOffsetReset(auto string) error { - return c.SetConfig("auto.offset.reset", auto) -} - -// SetConnectIP TMQ set `td.connect.ip` -func (c *Config) SetConnectIP(ip string) error { - return c.SetConfig("td.connect.ip", ip) -} - -// SetConnectUser TMQ set `td.connect.user` -func (c *Config) SetConnectUser(user string) error { - return c.SetConfig("td.connect.user", user) -} - -// SetConnectPass TMQ set `td.connect.pass` -func (c *Config) SetConnectPass(pass string) error { - return c.SetConfig("td.connect.pass", pass) -} - -// SetConnectPort TMQ set `td.connect.port` -func (c *Config) SetConnectPort(port string) error { - return c.SetConfig("td.connect.port", port) +func newConfig() *config { + return &config{cConfig: wrapper.TMQConfNew()} } -// SetMsgWithTableName TMQ set `msg.with.table.name` -func (c *Config) SetMsgWithTableName(b bool) error { - c.needGetTableName = b - return c.SetConfig("msg.with.table.name", strconv.FormatBool(b)) -} - -func (c *Config) SetConfig(key string, value string) error { +func (c *config) setConfig(key string, value string) error { errCode := wrapper.TMQConfSet(c.cConfig, key, value) if errCode != errors.SUCCESS { errStr := wrapper.TMQErr2Str(errCode) @@ -67,34 +24,7 @@ func (c *Config) SetConfig(key string, value string) error { return nil } -// EnableAutoCommit TMQ set `enable.auto.commit` to `true` and set auto commit callback -func (c *Config) EnableAutoCommit(f CommitHandleFunc) error { - err := c.SetConfig("enable.auto.commit", "true") - if err != nil { - return err - } - c.cb = f - c.autoCommit = true - return nil -} - -// DisableAutoCommit TMQ set `enable.auto.commit` to `false` -func (c *Config) DisableAutoCommit() error { - err := c.SetConfig("enable.auto.commit", "false") - if err != nil { - return err - } - c.cb = nil - c.autoCommit = false - return nil -} - -// EnableHeartBeat TMQ set `enable.heartbeat.background` to `true` -func (c *Config) EnableHeartBeat() error { - return c.SetConfig("enable.heartbeat.background", "true") -} - // Destroy Release TMQ config -func (c *Config) Destroy() { +func (c *config) destroy() { wrapper.TMQConfDestroy(c.cConfig) } diff --git a/af/tmq/config_test.go b/af/tmq/config_test.go index ddea011..4a0bbd2 100644 --- a/af/tmq/config_test.go +++ b/af/tmq/config_test.go @@ -7,8 +7,8 @@ import ( ) func TestConfig(t *testing.T) { - conf := NewConfig() - conf.Destroy() + conf := newConfig() + conf.destroy() } func TestList(t *testing.T) { diff --git a/af/tmq/consumer.go b/af/tmq/consumer.go index 7fba462..b575181 100644 --- a/af/tmq/consumer.go +++ b/af/tmq/consumer.go @@ -1,90 +1,76 @@ package tmq import ( - "context" - "database/sql/driver" - "time" + "errors" "unsafe" jsoniter "github.com/json-iterator/go" "github.com/taosdata/driver-go/v3/common" - "github.com/taosdata/driver-go/v3/errors" + "github.com/taosdata/driver-go/v3/common/parser" + "github.com/taosdata/driver-go/v3/common/tmq" + taosError "github.com/taosdata/driver-go/v3/errors" "github.com/taosdata/driver-go/v3/wrapper" - "github.com/taosdata/driver-go/v3/wrapper/cgo" -) - -var ( - ClosedError = errors.NewError(0xffff, "consumer closed") ) type Consumer struct { - conf *Config - cConsumer unsafe.Pointer - autoCommitChan chan *wrapper.TMQCommitCallbackResult - autoCommitHandle cgo.Handle - autoCommitHandleFunc CommitHandleFunc - asyncCommitChan chan *wrapper.TMQCommitCallbackResult - asyncCommitHandle cgo.Handle - exit chan struct{} + cConsumer unsafe.Pointer } // NewConsumer Create new TMQ consumer with TMQ config -func NewConsumer(conf *Config) (*Consumer, error) { - cConsumer, err := wrapper.TMQConsumerNew(conf.cConfig) +func NewConsumer(conf *tmq.ConfigMap) (*Consumer, error) { + confStruct, err := configMapToConfig(conf) if err != nil { return nil, err } - asyncChan := make(chan *wrapper.TMQCommitCallbackResult, 1) - asyncHandle := cgo.NewHandle(asyncChan) - consumer := &Consumer{ - conf: conf, - cConsumer: cConsumer, - exit: make(chan struct{}), - asyncCommitChan: asyncChan, - asyncCommitHandle: asyncHandle, + defer confStruct.destroy() + cConsumer, err := wrapper.TMQConsumerNew(confStruct.cConfig) + if err != nil { + return nil, err } - if conf.autoCommit { - autoChan := make(chan *wrapper.TMQCommitCallbackResult, 1) - autoHandle := cgo.NewHandle(autoChan) - wrapper.TMQConfSetAutoCommitCB(conf.cConfig, autoHandle) - consumer.autoCommitChan = autoChan - consumer.autoCommitHandle = autoHandle - consumer.handlerCommitCallback() + consumer := &Consumer{ + cConsumer: cConsumer, } return consumer, nil } -func (c *Consumer) handlerCommitCallback() { - go func() { - for { - select { - case <-c.exit: - c.autoCommitHandle.Delete() - close(c.asyncCommitChan) - return - case d := <-c.autoCommitChan: - c.autoCommitHandleFunc(d) - wrapper.PutTMQCommitCallbackResult(d) - } +func configMapToConfig(m *tmq.ConfigMap) (*config, error) { + c := newConfig() + confCopy := m.Clone() + for k, v := range confCopy { + vv, ok := v.(string) + if !ok { + c.destroy() + return nil, errors.New("config value requires string") } - }() + err := c.setConfig(k, vv) + if err != nil { + c.destroy() + return nil, err + } + } + return c, nil } -// Subscribe TMQ consumer subscribe topics -func (c *Consumer) Subscribe(topics []string) error { +type RebalanceCb func(*Consumer, tmq.Event) error + +func (c *Consumer) Subscribe(topic string, rebalanceCb RebalanceCb) error { + return c.SubscribeTopics([]string{topic}, rebalanceCb) +} + +func (c *Consumer) SubscribeTopics(topics []string, rebalanceCb RebalanceCb) error { topicList := wrapper.TMQListNew() defer wrapper.TMQListDestroy(topicList) for _, topic := range topics { errCode := wrapper.TMQListAppend(topicList, topic) if errCode != 0 { errStr := wrapper.TMQErr2Str(errCode) - return errors.NewError(int(errCode), errStr) + return taosError.NewError(int(errCode), errStr) } } errCode := wrapper.TMQSubscribe(c.cConsumer, topicList) if errCode != 0 { errStr := wrapper.TMQErr2Str(errCode) - return errors.NewError(int(errCode), errStr) + return taosError.NewError(int(errCode), errStr) } return nil } @@ -92,115 +78,129 @@ func (c *Consumer) Subscribe(topics []string) error { // Unsubscribe TMQ unsubscribe func (c *Consumer) Unsubscribe() error { errCode := wrapper.TMQUnsubscribe(c.cConsumer) - if errCode != errors.SUCCESS { + if errCode != taosError.SUCCESS { errStr := wrapper.TMQErr2Str(errCode) - return errors.NewError(int(errCode), errStr) + return taosError.NewError(int(errCode), errStr) } return nil } -type Result struct { - Type int32 - DBName string - Topic string - Message unsafe.Pointer - Meta *common.Meta - Data []*Data -} -type Data struct { - TableName string - Data [][]driver.Value -} - -//Poll consumer poll message with timeout -func (c *Consumer) Poll(timeout time.Duration) (*Result, error) { - message := wrapper.TMQConsumerPoll(c.cConsumer, timeout.Milliseconds()) +// Poll consumer poll message with timeout +func (c *Consumer) Poll(timeoutMs int) tmq.Event { + message := wrapper.TMQConsumerPoll(c.cConsumer, int64(timeoutMs)) if message == nil { - return nil, nil + return nil } topic := wrapper.TMQGetTopicName(message) db := wrapper.TMQGetDBName(message) resultType := wrapper.TMQGetResType(message) - result := &Result{ - Type: resultType, - DBName: db, - Topic: topic, - Message: message, - } switch resultType { + case common.TMQ_RES_DATA: + result := &tmq.DataMessage{} + result.SetDbName(db) + result.SetTopic(topic) + data, err := c.getData(message) + if err != nil { + return tmq.NewTMQErrorWithErr(err) + } + result.SetData(data) + wrapper.TaosFreeResult(message) + return result case common.TMQ_RES_TABLE_META: - var meta common.Meta - p := wrapper.TMQGetJsonMeta(message) - if p != nil { - data := wrapper.ParseJsonMeta(p) - wrapper.TMQFreeJsonMeta(p) - err := jsoniter.Unmarshal(data, &meta) - if err != nil { - return nil, err - } - result.Meta = &meta + result := &tmq.MetaMessage{} + result.SetDbName(db) + result.SetTopic(topic) + meta, err := c.getMeta(message) + if err != nil { + return tmq.NewTMQErrorWithErr(err) } - return result, nil - case common.TMQ_RES_DATA: - for { - blockSize, errCode, block := wrapper.TaosFetchRawBlock(message) - if errCode != int(errors.SUCCESS) { - errStr := wrapper.TaosErrorStr(message) - err := errors.NewError(errCode, errStr) - return nil, err - } - if blockSize == 0 { - break - } - r := &Data{} - if c.conf.needGetTableName { - r.TableName = wrapper.TMQGetTableName(message) - } - fileCount := wrapper.TaosNumFields(message) - rh, err := wrapper.ReadColumn(message, fileCount) - if err != nil { - return nil, err - } - precision := wrapper.TaosResultPrecision(message) - r.Data = append(r.Data, wrapper.ReadBlock(block, blockSize, rh.ColTypes, precision)...) - result.Data = append(result.Data, r) + result.SetMeta(meta) + wrapper.TaosFreeResult(message) + return result + case common.TMQ_RES_METADATA: + result := &tmq.MetaDataMessage{} + result.SetDbName(db) + result.SetTopic(topic) + data, err := c.getData(message) + if err != nil { + return tmq.NewTMQErrorWithErr(err) + } + meta, err := c.getMeta(message) + if err != nil { + return tmq.NewTMQErrorWithErr(err) } - return result, nil + result.SetMetaData(&tmq.MetaData{ + Meta: meta, + Data: data, + }) + wrapper.TaosFreeResult(message) + return result default: - return nil, errors.NewError(0xfffff, "invalid tmq message type") + return tmq.NewTMQError(0xfffff, "invalid tmq message type") } } -// FreeMessage Release message after commit -func (c *Consumer) FreeMessage(message unsafe.Pointer) { - wrapper.TaosFreeResult(message) +func (c *Consumer) getMeta(message unsafe.Pointer) (*tmq.Meta, error) { + var meta tmq.Meta + p := wrapper.TMQGetJsonMeta(message) + if p != nil { + data := wrapper.ParseJsonMeta(p) + wrapper.TMQFreeJsonMeta(p) + err := jsoniter.Unmarshal(data, &meta) + if err != nil { + return nil, err + } + return &meta, nil + } + return &meta, nil } -//Commit commit message -func (c *Consumer) Commit(ctx context.Context, message unsafe.Pointer) error { - wrapper.TMQCommitAsync(c.cConsumer, message, c.asyncCommitHandle) +func (c *Consumer) getData(message unsafe.Pointer) ([]*tmq.Data, error) { + var tmqData []*tmq.Data for { - select { - case <-c.exit: - c.asyncCommitHandle.Delete() - close(c.asyncCommitChan) - return ClosedError - case <-ctx.Done(): - return ctx.Err() - case d := <-c.asyncCommitChan: - return d.GetError() + blockSize, errCode, block := wrapper.TaosFetchRawBlock(message) + if errCode != int(taosError.SUCCESS) { + errStr := wrapper.TaosErrorStr(message) + err := taosError.NewError(errCode, errStr) + return nil, err + } + if blockSize == 0 { + break } + tableName := wrapper.TMQGetTableName(message) + fileCount := wrapper.TaosNumFields(message) + rh, err := wrapper.ReadColumn(message, fileCount) + if err != nil { + return nil, err + } + precision := wrapper.TaosResultPrecision(message) + tmqData = append(tmqData, &tmq.Data{ + TableName: tableName, + Data: parser.ReadBlock(block, blockSize, rh.ColTypes, precision), + }) + } + return tmqData, nil +} + +func (c *Consumer) Commit() ([]tmq.TopicPartition, error) { + return c.doCommit(nil) +} + +func (c *Consumer) doCommit(message unsafe.Pointer) ([]tmq.TopicPartition, error) { + errCode := wrapper.TMQCommitSync(c.cConsumer, message) + if errCode != taosError.SUCCESS { + errStr := wrapper.TMQErr2Str(errCode) + return nil, taosError.NewError(int(errCode), errStr) } + return nil, nil } // Close release consumer func (c *Consumer) Close() error { - defer c.autoCommitHandle.Delete() errCode := wrapper.TMQConsumerClose(c.cConsumer) if errCode != 0 { errStr := wrapper.TMQErr2Str(errCode) - return errors.NewError(int(errCode), errStr) + return taosError.NewError(int(errCode), errStr) } - close(c.exit) return nil } diff --git a/af/tmq/consumer_test.go b/af/tmq/consumer_test.go index 93715dd..bef8aaf 100644 --- a/af/tmq/consumer_test.go +++ b/af/tmq/consumer_test.go @@ -1,12 +1,12 @@ package tmq import ( - "context" "fmt" "testing" "time" "github.com/stretchr/testify/assert" + "github.com/taosdata/driver-go/v3/common/tmq" "github.com/taosdata/driver-go/v3/errors" "github.com/taosdata/driver-go/v3/wrapper" ) @@ -30,7 +30,7 @@ func TestTmq(t *testing.T) { wrapper.TaosFreeResult(result) }() - result := wrapper.TaosQuery(conn, "create database if not exists af_test_tmq vgroups 2") + result := wrapper.TaosQuery(conn, "create database if not exists af_test_tmq vgroups 2 WAL_RETENTION_PERIOD 86400") code := wrapper.TaosError(result) if code != 0 { errStr := wrapper.TaosErrorStr(result) @@ -135,68 +135,61 @@ func TestTmq(t *testing.T) { } wrapper.TaosFreeResult(result) - config := NewConfig() - defer config.Destroy() - err = config.SetGroupID("test") assert.NoError(t, err) - err = config.SetAutoOffsetReset("earliest") - assert.NoError(t, err) - err = config.SetConnectIP("127.0.0.1") - assert.NoError(t, err) - err = config.SetConnectUser("root") - assert.NoError(t, err) - err = config.SetConnectPass("taosdata") - assert.NoError(t, err) - err = config.SetConnectPort("6030") - assert.NoError(t, err) - err = config.SetMsgWithTableName(true) - assert.NoError(t, err) - err = config.EnableAutoCommit(func(result *wrapper.TMQCommitCallbackResult) { - if result.ErrCode != 0 { - errStr := wrapper.TMQErr2Str(result.ErrCode) - err := errors.NewError(int(result.ErrCode), errStr) - t.Error(err) - return - } + consumer, err := NewConsumer(&tmq.ConfigMap{ + "group.id": "test", + "auto.offset.reset": "earliest", + "td.connect.ip": "127.0.0.1", + "td.connect.user": "root", + "td.connect.pass": "taosdata", + "td.connect.port": "6030", + "client.id": "test_tmq_c", + "enable.auto.commit": "false", + "experimental.snapshot.enable": "true", + "msg.with.table.name": "true", }) - assert.NoError(t, err) - consumer, err := NewConsumer(config) if err != nil { t.Error(err) return } - err = consumer.Subscribe([]string{"test_tmq_common"}) + err = consumer.Subscribe("test_tmq_common", nil) if err != nil { t.Error(err) return } - message, err := consumer.Poll(500 * time.Millisecond) - if err != nil { - t.Error(err) - return + for i := 0; i < 5; i++ { + ev := consumer.Poll(500) + switch e := ev.(type) { + case *tmq.DataMessage: + row1 := e.Value().([]*tmq.Data)[0].Data[0] + assert.Equal(t, "af_test_tmq", e.DBName()) + assert.Equal(t, now.UnixNano()/1e6, row1[0].(time.Time).UnixNano()/1e6) + assert.Equal(t, true, row1[1].(bool)) + assert.Equal(t, int8(2), row1[2].(int8)) + assert.Equal(t, int16(3), row1[3].(int16)) + assert.Equal(t, int32(4), row1[4].(int32)) + assert.Equal(t, int64(5), row1[5].(int64)) + assert.Equal(t, uint8(6), row1[6].(uint8)) + assert.Equal(t, uint16(7), row1[7].(uint16)) + assert.Equal(t, uint32(8), row1[8].(uint32)) + assert.Equal(t, uint64(9), row1[9].(uint64)) + assert.Equal(t, float32(10), row1[10].(float32)) + assert.Equal(t, float64(11), row1[11].(float64)) + assert.Equal(t, "1", row1[12].(string)) + assert.Equal(t, "2", row1[13].(string)) + _, err = consumer.Commit() + assert.NoError(t, err) + err = consumer.Unsubscribe() + assert.NoError(t, err) + err = consumer.Close() + assert.NoError(t, err) + return + case tmq.Error: + t.Error(e) + return + default: + t.Error("unexpected", e) + return + } } - - row1 := message.Data[0].Data[0] - assert.Equal(t, "af_test_tmq", message.DBName) - assert.Equal(t, now.UnixNano()/1e6, row1[0].(time.Time).UnixNano()/1e6) - assert.Equal(t, true, row1[1].(bool)) - assert.Equal(t, int8(2), row1[2].(int8)) - assert.Equal(t, int16(3), row1[3].(int16)) - assert.Equal(t, int32(4), row1[4].(int32)) - assert.Equal(t, int64(5), row1[5].(int64)) - assert.Equal(t, uint8(6), row1[6].(uint8)) - assert.Equal(t, uint16(7), row1[7].(uint16)) - assert.Equal(t, uint32(8), row1[8].(uint32)) - assert.Equal(t, uint64(9), row1[9].(uint64)) - assert.Equal(t, float32(10), row1[10].(float32)) - assert.Equal(t, float64(11), row1[11].(float64)) - assert.Equal(t, "1", row1[12].(string)) - assert.Equal(t, "2", row1[13].(string)) - err = consumer.Commit(context.Background(), nil) - consumer.FreeMessage(message.Message) - assert.NoError(t, err) - err = consumer.Unsubscribe() - assert.NoError(t, err) - err = consumer.Close() - assert.NoError(t, err) } diff --git a/bench/driver/compare.go b/bench/driver/compare.go index 6d70f1a..39d4874 100644 --- a/bench/driver/compare.go +++ b/bench/driver/compare.go @@ -72,7 +72,6 @@ func testQueryC() { result, err := cdb.Query(`select * from benchmark_driver.alltype_query limit 3000`) if err != nil { panic(err) - return } for result.Next() { var ( @@ -109,11 +108,10 @@ func testQueryC() { ) if err != nil { panic(err) - return } } } - delta := time.Now().Sub(s).Nanoseconds() + delta := time.Since(s).Nanoseconds() fmt.Println("cgo query", float64(delta)/1000) } @@ -139,7 +137,6 @@ func testQueryRestful() { result, err := restfulDB.Query(`select * from benchmark_driver.alltype_query limit 3000`) if err != nil { panic(err) - return } for result.Next() { err := result.Scan( @@ -160,11 +157,10 @@ func testQueryRestful() { ) if err != nil { panic(err) - return } } } - delta := time.Now().Sub(s).Nanoseconds() + delta := time.Since(s).Nanoseconds() fmt.Println("restful query", float64(delta)/1000) } @@ -220,10 +216,9 @@ func testCGO() { _, err := cdb.Exec(dataC[i]) if err != nil { panic(err) - return } } - delta := time.Now().Sub(s).Nanoseconds() + delta := time.Since(s).Nanoseconds() fmt.Println("cgo", float64(delta)/50000) } @@ -233,9 +228,8 @@ func testRestful() { _, err := restfulDB.Exec(dataRestful[i]) if err != nil { panic(err) - return } } - delta := time.Now().Sub(s).Nanoseconds() + delta := time.Since(s).Nanoseconds() fmt.Println("restful", float64(delta)/50000) } diff --git a/bench/query/main.go b/bench/query/main.go index 585e22c..d316359 100644 --- a/bench/query/main.go +++ b/bench/query/main.go @@ -14,7 +14,6 @@ var ( password = "taosdata" host = "" port = 6030 - dbName = "test_taos_sql" dataSourceName = fmt.Sprintf("%s:%s@/tcp(%s:%d)/%s?interpolateParams=true", user, password, host, port, "") ) @@ -22,43 +21,35 @@ func main() { db, err := sql.Open(driverName, dataSourceName) if err != nil { panic(err) - return } defer db.Close() _, err = db.Exec("create database if not exists test_json") if err != nil { panic(err) - return } _, err = db.Exec("drop table if exists test_json.tjson") if err != nil { panic(err) - return } _, err = db.Exec("create stable if not exists test_json.tjson(ts timestamp,v int )tags(t json)") if err != nil { panic(err) - return } _, err = db.Exec(`insert into test_json.tj_1 using test_json.tjson tags('{"a":1,"b":"b"}')values (now,1)`) if err != nil { panic(err) - return } _, err = db.Exec(`insert into test_json.tj_2 using test_json.tjson tags('{"a":1,"c":"c"}')values (now,1)`) if err != nil { panic(err) - return } _, err = db.Exec(`insert into test_json.tj_3 using test_json.tjson tags('null')values (now,1)`) if err != nil { panic(err) - return } rows, err := db.Query("select t from test_json.tjson") if err != nil { panic(err) - return } counter := 0 for rows.Next() { @@ -66,7 +57,6 @@ func main() { err := rows.Scan(&info) if err != nil { panic(err) - return } if info != nil && !json.Valid(info) { fmt.Println("invalid json ", string(info)) diff --git a/bench/standard/executor/test.go b/bench/standard/executor/test.go index bf8d0de..68c80f5 100644 --- a/bench/standard/executor/test.go +++ b/bench/standard/executor/test.go @@ -182,13 +182,13 @@ func (t *TDTest) BenchmarkWriteSingleCommon(count int) { if err != nil { panic(err) } - fmt.Printf("%s : create table cost: %d ns\n", prefix, time.Now().Sub(s).Nanoseconds()) + fmt.Printf("%s : create table cost: %d ns\n", prefix, time.Since(s).Nanoseconds()) s = time.Now() sqls := make([]string, count) for i := 0; i < count; i++ { sqls[i] = fmt.Sprintf("insert into wsc values(%d,true,2,3,4,5,6,7,8,9,10,11,'binary','nchar')", now+int64(i)) } - fmt.Printf("%s : prepare sql cost: %d ns\n", prefix, time.Now().Sub(s).Nanoseconds()) + fmt.Printf("%s : prepare sql cost: %d ns\n", prefix, time.Since(s).Nanoseconds()) s = time.Now() for i := 0; i < count; i++ { _, err = db.Exec(sqls[i]) @@ -196,7 +196,7 @@ func (t *TDTest) BenchmarkWriteSingleCommon(count int) { panic(err) } } - cost := time.Now().Sub(s) + cost := time.Since(s) fmt.Printf("%s :execute count: %d, execute cost: %d ns, average cost: %f ns\n", prefix, count, cost.Nanoseconds(), float64(cost.Nanoseconds())/float64(count)) } @@ -214,13 +214,13 @@ func (t *TDTest) BenchmarkWriteSingleJson(count int) { if err != nil { panic(err) } - fmt.Printf("%s : create table cost: %d ns\n", prefix, time.Now().Sub(s).Nanoseconds()) + fmt.Printf("%s : create table cost: %d ns\n", prefix, time.Since(s).Nanoseconds()) s = time.Now() sqls := make([]string, count) for i := 0; i < count; i++ { sqls[i] = fmt.Sprintf("insert into wsj values(%d,true,2,3,4,5,6,7,8,9,10,11,'binary','nchar')", now+int64(i)) } - fmt.Printf("%s : prepare sql cost: %d ns\n", prefix, time.Now().Sub(s).Nanoseconds()) + fmt.Printf("%s : prepare sql cost: %d ns\n", prefix, time.Since(s).Nanoseconds()) s = time.Now() for i := 0; i < count; i++ { _, err = db.Exec(sqls[i]) @@ -228,7 +228,7 @@ func (t *TDTest) BenchmarkWriteSingleJson(count int) { panic(err) } } - cost := time.Now().Sub(s) + cost := time.Since(s) fmt.Printf("%s :execute count: %d, execute cost: %d ns, average cost: %f ns\n", prefix, count, cost.Nanoseconds(), float64(cost.Nanoseconds())/float64(count)) } @@ -246,7 +246,7 @@ func (t *TDTest) BenchmarkWriteBatchCommon(count, batch int) { if err != nil { panic(err) } - fmt.Printf("%s : create table cost: %d ns\n", prefix, time.Now().Sub(s).Nanoseconds()) + fmt.Printf("%s : create table cost: %d ns\n", prefix, time.Since(s).Nanoseconds()) s = time.Now() sqls := make([]string, count) b := &bytes.Buffer{} @@ -260,7 +260,7 @@ func (t *TDTest) BenchmarkWriteBatchCommon(count, batch int) { sqls[i] = b.String() b.Reset() } - fmt.Printf("%s : prepare sql cost: %d ns\n", prefix, time.Now().Sub(s).Nanoseconds()) + fmt.Printf("%s : prepare sql cost: %d ns\n", prefix, time.Since(s).Nanoseconds()) s = time.Now() for i := 0; i < count; i++ { _, err = db.Exec(sqls[i]) @@ -268,7 +268,7 @@ func (t *TDTest) BenchmarkWriteBatchCommon(count, batch int) { panic(err) } } - cost := time.Now().Sub(s) + cost := time.Since(s) fmt.Printf("%s :execute count: %d, batch: %d, total record: %d, execute cost: %d ns, average count cost: %f ns,average record cost %f\n", prefix, count, @@ -293,7 +293,7 @@ func (t *TDTest) BenchmarkWriteBatchJson(count, batch int) { if err != nil { panic(err) } - fmt.Printf("%s : create table cost: %d ns\n", prefix, time.Now().Sub(s).Nanoseconds()) + fmt.Printf("%s : create table cost: %d ns\n", prefix, time.Since(s).Nanoseconds()) s = time.Now() sqls := make([]string, count) b := &bytes.Buffer{} @@ -307,7 +307,7 @@ func (t *TDTest) BenchmarkWriteBatchJson(count, batch int) { sqls[i] = b.String() b.Reset() } - fmt.Printf("%s : prepare sql cost: %d ns\n", prefix, time.Now().Sub(s).Nanoseconds()) + fmt.Printf("%s : prepare sql cost: %d ns\n", prefix, time.Since(s).Nanoseconds()) s = time.Now() for i := 0; i < count; i++ { _, err = db.Exec(sqls[i]) @@ -315,7 +315,7 @@ func (t *TDTest) BenchmarkWriteBatchJson(count, batch int) { panic(err) } } - cost := time.Now().Sub(s) + cost := time.Since(s) fmt.Printf("%s :execute count: %d, batch: %d, total record: %d, execute cost: %d ns, average count cost: %f ns,average record cost %f\n", prefix, count, @@ -360,6 +360,9 @@ func (t *TDTest) PrepareRead(count, batch int) (tableName string) { "c12 binary(20)," + "c13 nchar(20)" + ")tags(info json)") + if err != nil { + panic(err) + } prefix := t.DriverName + ": PrepareRead" s := time.Now() @@ -369,7 +372,7 @@ func (t *TDTest) PrepareRead(count, batch int) (tableName string) { panic(err) } - fmt.Printf("%s : create table cost: %d ns\n", prefix, time.Now().Sub(s).Nanoseconds()) + fmt.Printf("%s : create table cost: %d ns\n", prefix, time.Since(s).Nanoseconds()) s = time.Now() sqls := make([]string, count) b := &bytes.Buffer{} @@ -383,7 +386,7 @@ func (t *TDTest) PrepareRead(count, batch int) (tableName string) { sqls[i] = b.String() b.Reset() } - fmt.Printf("%s : prepare sql cost: %d ns\n", prefix, time.Now().Sub(s).Nanoseconds()) + fmt.Printf("%s : prepare sql cost: %d ns\n", prefix, time.Since(s).Nanoseconds()) s = time.Now() for i := 0; i < count; i++ { _, err = db.Exec(sqls[i]) @@ -391,7 +394,7 @@ func (t *TDTest) PrepareRead(count, batch int) (tableName string) { panic(err) } } - cost := time.Now().Sub(s) + cost := time.Since(s) fmt.Printf("%s :execute count: %d, batch: %d, total record: %d, execute cost: %d ns, average count cost: %f ns,average record cost %f\n", prefix, count, @@ -414,7 +417,7 @@ func (t *TDTest) BenchmarkRead(sqlStr string) { if err != nil { panic(err) } - fmt.Printf("%s : query: %d ns\n", prefix, time.Now().Sub(s)) + fmt.Printf("%s : query: %d ns\n", prefix, time.Since(s)) tt, err := rows.ColumnTypes() if err != nil { log.Fatalf("ColumnTypes: %v", err) @@ -437,7 +440,7 @@ func (t *TDTest) BenchmarkRead(sqlStr string) { count += 1 rows.Scan(values...) } - cost := time.Now().Sub(s) + cost := time.Since(s) fmt.Printf("%s : result count: %d, execute cost: %d ns, average count cost: %f ns\n", prefix, count, diff --git a/bench/standard/native/exec/write.go b/bench/standard/native/exec/write.go index 070101e..d85b0b7 100644 --- a/bench/standard/native/exec/write.go +++ b/bench/standard/native/exec/write.go @@ -15,7 +15,7 @@ func main() { s := time.Now() singleCount := 1000 test.BenchmarkWriteSingleCommon(singleCount) - writeSingleCommonCost := time.Now().Sub(s) + writeSingleCommonCost := time.Since(s) fmt.Printf("write single common, count: %d,cost: %d ns,average: %f ns\n", singleCount, writeSingleCommonCost.Nanoseconds(), float64(writeSingleCommonCost.Nanoseconds())/float64(singleCount)) test.Clean() @@ -24,7 +24,7 @@ func main() { batch := 100 s = time.Now() test.BenchmarkWriteBatchJson(batchCount, batch) - writeBatchCost := time.Now().Sub(s) + writeBatchCost := time.Since(s) fmt.Printf("write batch common, count: %d,cost: %d ns,average: %f ns\n", batchCount, writeBatchCost.Nanoseconds(), float64(writeBatchCost.Nanoseconds())/float64(batch*batchCount)) } @@ -33,7 +33,7 @@ func main() { s := time.Now() singleCount := 1000 test.BenchmarkWriteSingleJson(singleCount) - writeSingleCommonCost := time.Now().Sub(s) + writeSingleCommonCost := time.Since(s) fmt.Printf("write single json, count: %d,cost: %d ns,average: %f ns\n", singleCount, writeSingleCommonCost.Nanoseconds(), float64(writeSingleCommonCost.Nanoseconds())/float64(singleCount)) test.Clean() @@ -42,7 +42,7 @@ func main() { batch := 100 s = time.Now() test.BenchmarkWriteBatchJson(batchCount, batch) - writeBatchCost := time.Now().Sub(s) + writeBatchCost := time.Since(s) fmt.Printf("write batch json, count: %d,cost: %d ns,average: %f ns\n", batchCount, writeBatchCost.Nanoseconds(), float64(writeBatchCost.Nanoseconds())/float64(batch*batchCount)) } } diff --git a/bench/standard/restful/exec/write.go b/bench/standard/restful/exec/write.go index 8b579e6..7d9486d 100644 --- a/bench/standard/restful/exec/write.go +++ b/bench/standard/restful/exec/write.go @@ -15,7 +15,7 @@ func main() { s := time.Now() singleCount := 1000 test.BenchmarkWriteSingleCommon(singleCount) - writeSingleCommonCost := time.Now().Sub(s) + writeSingleCommonCost := time.Since(s) fmt.Printf("write single common, count: %d,cost: %d ns,average: %f ns\n", singleCount, writeSingleCommonCost.Nanoseconds(), float64(writeSingleCommonCost.Nanoseconds())/float64(singleCount)) test.Clean() @@ -24,7 +24,7 @@ func main() { batch := 100 s = time.Now() test.BenchmarkWriteBatchJson(batchCount, batch) - writeBatchCost := time.Now().Sub(s) + writeBatchCost := time.Since(s) fmt.Printf("write batch common, count: %d,cost: %d ns,average: %f ns\n", batchCount, writeBatchCost.Nanoseconds(), float64(writeBatchCost.Nanoseconds())/float64(batch*batchCount)) } @@ -33,7 +33,7 @@ func main() { s := time.Now() singleCount := 1000 test.BenchmarkWriteSingleJson(singleCount) - writeSingleCommonCost := time.Now().Sub(s) + writeSingleCommonCost := time.Since(s) fmt.Printf("write single json, count: %d,cost: %d ns,average: %f ns\n", singleCount, writeSingleCommonCost.Nanoseconds(), float64(writeSingleCommonCost.Nanoseconds())/float64(singleCount)) test.Clean() @@ -42,7 +42,7 @@ func main() { batch := 100 s = time.Now() test.BenchmarkWriteBatchJson(batchCount, batch) - writeBatchCost := time.Now().Sub(s) + writeBatchCost := time.Since(s) fmt.Printf("write batch json, count: %d,cost: %d ns,average: %f ns\n", batchCount, writeBatchCost.Nanoseconds(), float64(writeBatchCost.Nanoseconds())/float64(batch*batchCount)) } } diff --git a/bench/stmt/insert/main.go b/bench/stmt/insert/main.go index 9c8f604..72cf0fc 100644 --- a/bench/stmt/insert/main.go +++ b/bench/stmt/insert/main.go @@ -15,7 +15,11 @@ import ( ) func main() { - go http.ListenAndServe(":6060", nil) + go func() { + if err := http.ListenAndServe(":6060", nil); err != nil { + panic(err) + } + }() conn, err := wrapper.TaosConnect("", "root", "taosdata", "", 0) if err != nil { panic(err) diff --git a/bench/stmt/query/main.go b/bench/stmt/query/main.go index d47f118..78c7ea5 100644 --- a/bench/stmt/query/main.go +++ b/bench/stmt/query/main.go @@ -9,12 +9,17 @@ import ( _ "net/http/pprof" "github.com/taosdata/driver-go/v3/common/param" + "github.com/taosdata/driver-go/v3/common/parser" "github.com/taosdata/driver-go/v3/errors" "github.com/taosdata/driver-go/v3/wrapper" ) func main() { - go http.ListenAndServe(":6060", nil) + go func() { + if err := http.ListenAndServe(":6060", nil); err != nil { + panic(err) + } + }() conn, err := wrapper.TaosConnect("", "root", "taosdata", "", 0) if err != nil { panic(err) @@ -110,7 +115,7 @@ func StmtQuery(conn unsafe.Pointer, sql string, params *param.Param) (rows [][]d if blockSize == 0 { break } - d := wrapper.ReadBlock(block, blockSize, rowsHeader.ColTypes, precision) + d := parser.ReadBlock(block, blockSize, rowsHeader.ColTypes, precision) data = append(data, d...) } wrapper.TaosFreeResult(res) diff --git a/bench/test/tmq/multiinsert.go b/bench/test/tmq/multiinsert.go index 34ad885..547dd4b 100644 --- a/bench/test/tmq/multiinsert.go +++ b/bench/test/tmq/multiinsert.go @@ -7,17 +7,21 @@ import ( _ "net/http/pprof" + "github.com/taosdata/driver-go/v3/common/parser" "github.com/taosdata/driver-go/v3/errors" "github.com/taosdata/driver-go/v3/wrapper" "github.com/taosdata/driver-go/v3/wrapper/cgo" ) func main() { - go http.ListenAndServe(":6060", nil) + go func() { + if err := http.ListenAndServe(":6060", nil); err != nil { + panic(err) + } + }() conn, err := wrapper.TaosConnect("", "root", "taosdata", "", 0) if err != nil { panic(err) - return } result := wrapper.TaosQuery(conn, "create database if not exists tmq_test_db_multi_insert vgroups 2") @@ -26,7 +30,6 @@ func main() { errStr := wrapper.TaosErrorStr(result) wrapper.TaosFreeResult(result) panic(errors.TaosError{Code: int32(code), ErrStr: errStr}) - return } wrapper.TaosFreeResult(result) @@ -36,7 +39,6 @@ func main() { errStr := wrapper.TaosErrorStr(result) wrapper.TaosFreeResult(result) panic(errors.TaosError{Code: int32(code), ErrStr: errStr}) - return } wrapper.TaosFreeResult(result) @@ -46,7 +48,6 @@ func main() { errStr := wrapper.TaosErrorStr(result) wrapper.TaosFreeResult(result) panic(errors.TaosError{Code: int32(code), ErrStr: errStr}) - return } wrapper.TaosFreeResult(result) @@ -56,7 +57,6 @@ func main() { errStr := wrapper.TaosErrorStr(result) wrapper.TaosFreeResult(result) panic(errors.TaosError{Code: int32(code), ErrStr: errStr}) - return } wrapper.TaosFreeResult(result) @@ -66,7 +66,6 @@ func main() { errStr := wrapper.TaosErrorStr(result) wrapper.TaosFreeResult(result) panic(errors.TaosError{Code: int32(code), ErrStr: errStr}) - return } wrapper.TaosFreeResult(result) @@ -77,7 +76,6 @@ func main() { errStr := wrapper.TaosErrorStr(result) wrapper.TaosFreeResult(result) panic(errors.TaosError{Code: int32(code), ErrStr: errStr}) - return } wrapper.TaosFreeResult(result) go func() { @@ -88,7 +86,6 @@ func main() { errStr := wrapper.TaosErrorStr(result) wrapper.TaosFreeResult(result) panic(errors.TaosError{Code: int32(code), ErrStr: errStr}) - return } wrapper.TaosFreeResult(result) } @@ -103,11 +100,8 @@ func main() { h := cgo.NewHandle(c) wrapper.TMQConfSetAutoCommitCB(conf, h) go func() { - for { - select { - case r := <-c: - wrapper.PutTMQCommitCallbackResult(r) - } + for r := range c { + wrapper.PutTMQCommitCallbackResult(r) } }() tmq, err := wrapper.TMQConsumerNew(conf) @@ -124,7 +118,6 @@ func main() { if errCode != 0 { errStr := wrapper.TMQErr2Str(errCode) panic(errors.NewError(int(errCode), errStr)) - return } totalCount := 0 c2 := make(chan *wrapper.TMQCommitCallbackResult, 1) @@ -141,7 +134,6 @@ func main() { err := errors.NewError(errCode, errStr) wrapper.TaosFreeResult(message) panic(err) - return } if blockSize == 0 { break @@ -152,11 +144,10 @@ func main() { rh, err := wrapper.ReadColumn(message, filedCount) if err != nil { panic(err) - return } precision := wrapper.TaosResultPrecision(message) totalCount += blockSize - data := wrapper.ReadBlock(block, blockSize, rh.ColTypes, precision) + data := parser.ReadBlock(block, blockSize, rh.ColTypes, precision) fmt.Println(data) } wrapper.TaosFreeResult(message) @@ -171,7 +162,6 @@ func main() { case <-timer.C: timer.Stop() panic("wait tmq commit callback timeout") - return } } table = table[:0] diff --git a/bench/tmq/main.go b/bench/tmq/main.go index 23035da..79faed0 100644 --- a/bench/tmq/main.go +++ b/bench/tmq/main.go @@ -5,6 +5,7 @@ import ( "log" "time" + "github.com/taosdata/driver-go/v3/common/parser" "github.com/taosdata/driver-go/v3/errors" "github.com/taosdata/driver-go/v3/wrapper" "github.com/taosdata/driver-go/v3/wrapper/cgo" @@ -14,7 +15,6 @@ func main() { conn, err := wrapper.TaosConnect("", "root", "taosdata", "", 0) if err != nil { panic(err) - return } result := wrapper.TaosQuery(conn, "create database if not exists abc1 vgroups 2") @@ -32,7 +32,6 @@ func main() { errStr := wrapper.TaosErrorStr(result) wrapper.TaosFreeResult(result) panic(errors.TaosError{Code: int32(code), ErrStr: errStr}) - return } wrapper.TaosFreeResult(result) @@ -42,7 +41,6 @@ func main() { errStr := wrapper.TaosErrorStr(result) wrapper.TaosFreeResult(result) panic(errors.TaosError{Code: int32(code), ErrStr: errStr}) - return } wrapper.TaosFreeResult(result) @@ -52,7 +50,6 @@ func main() { errStr := wrapper.TaosErrorStr(result) wrapper.TaosFreeResult(result) panic(errors.TaosError{Code: int32(code), ErrStr: errStr}) - return } wrapper.TaosFreeResult(result) @@ -62,7 +59,6 @@ func main() { errStr := wrapper.TaosErrorStr(result) wrapper.TaosFreeResult(result) panic(errors.TaosError{Code: int32(code), ErrStr: errStr}) - return } wrapper.TaosFreeResult(result) @@ -72,7 +68,6 @@ func main() { errStr := wrapper.TaosErrorStr(result) wrapper.TaosFreeResult(result) panic(errors.TaosError{Code: int32(code), ErrStr: errStr}) - return } wrapper.TaosFreeResult(result) @@ -83,7 +78,6 @@ func main() { errStr := wrapper.TaosErrorStr(result) wrapper.TaosFreeResult(result) panic(errors.TaosError{Code: int32(code), ErrStr: errStr}) - return } wrapper.TaosFreeResult(result) go func() { @@ -97,7 +91,6 @@ func main() { errStr := wrapper.TaosErrorStr(result) wrapper.TaosFreeResult(result) panic(errors.TaosError{Code: int32(code), ErrStr: errStr}) - return } log.Println("start free result") wrapper.TaosFreeResult(result) @@ -128,7 +121,6 @@ func main() { if errCode != 0 { errStr := wrapper.TMQErr2Str(errCode) panic(errors.NewError(int(errCode), errStr)) - return } c2 := make(chan *wrapper.TMQCommitCallbackResult, 1) h2 := cgo.NewHandle(c2) @@ -140,7 +132,6 @@ func main() { rh, err := wrapper.ReadColumn(message, fileCount) if err != nil { panic(err) - return } precision := wrapper.TaosResultPrecision(message) for { @@ -150,12 +141,11 @@ func main() { err := errors.NewError(errCode, errStr) wrapper.TaosFreeResult(message) panic(err) - return } if blockSize == 0 { break } - data := wrapper.ReadBlock(block, blockSize, rh.ColTypes, precision) + data := parser.ReadBlock(block, blockSize, rh.ColTypes, precision) fmt.Println(data) } wrapper.TaosFreeResult(message) @@ -173,7 +163,6 @@ func main() { case <-timer.C: timer.Stop() panic("wait tmq commit callback timeout") - return } } } diff --git a/benchmark/README.md b/benchmark/README.md new file mode 100644 index 0000000..5fb366b --- /dev/null +++ b/benchmark/README.md @@ -0,0 +1,32 @@ +# Benchmark for TDengine-go-driver + +## Test tool + +We use [hyperfine](https://github.com/sharkdp/hyperfine) to test TDengin-go-driver + +## Test case + +- insert +- batch insert +- query +- average + +## Usage + +```shell +sh run_bench.sh ${BENCHMARK_TIMES} ${BATCH_TABLES} ${BATCH_ROWS} +``` + +- BENCHMARK_TIMES: ${BENCHMARK_TIMES} identifies how many tests [Hyperfine](https://github.com/sharkdp/hyperfine) will + perform. +- BATCH_TABLES: ${BENCHMARK_TIMES} identifies how many sub-tables will be used in batch insert testing case. In this + benchmark, there are 10000 sub-tables in each super table. So this value should not greater than 10000. +- BATCH_ROWS: ${BATCH_ROWS} identifies how many rows will be inserted into each sub-table in batch insert case. + The maximum SQL length in TDengine is 1M. Therefore, if this parameter is too large, the benchmark will fail. In this + benchmark, this value should not greater than 5000. + +example: + +```shell +sh run_bench.sh 10 100 1000 +``` diff --git a/benchmark/benchmark.go b/benchmark/benchmark.go new file mode 100644 index 0000000..5d69e79 --- /dev/null +++ b/benchmark/benchmark.go @@ -0,0 +1,266 @@ +package main + +import ( + "bytes" + "context" + "database/sql" + "flag" + "fmt" + "log" + "math/rand" + "time" + + _ "github.com/taosdata/driver-go/v3/taosSql" +) + +const ( + taosDb = "root:taosdata@tcp(localhost:6030)/" + benchmarkDb = "benchmark" + insertCmd = "insert" + batchInsertCmd = "batch" + queryCmd = "query" + avgCmd = "avg" + normalType = "normal" + jsonType = "json" + stb = "stb" + jtb = "jtb" + queryStb = "select * from stb" + queryJtb = "select ts, bl, i8, i16, i32, i64, u8, u16, u32, u64, f32, d64, bnr, nchr, jtag->\"k0\", jtag->\"k1\", jtag->\"k2\", jtag->\"k3\" from jtb;" + avgStbSql = "select avg(d64) from stb" + avgJtbSql = "select avg(d64) from jtb" + maxTableCnt = 10000 +) + +func main() { + ctx := context.Background() + cmd := flag.String("s", "connect", "Benchmark stage, \"connect\",\"insert\",\"query\",\"avg\",\"batch\",\"clean\",default \"connect\\\"") + types := flag.String("t", "normal", "Benchmark data type, table with\"json\" tag,table with \"normal\" column type,default \"normal\"") + tableCount := flag.Int("b", 1, "number of target tables,only for insert.Default 1 tables") + numOfRow := flag.Int("r", 1, "number of record per table,only for insert.Default 1 records") + times := flag.Int("n", 1, "number of times to run.Default 1 time.") + debug := flag.Bool("debug", false, "debug model") + flag.Parse() + + if *debug { + log.Printf("[debug] benchmark:\n cmd-[%s]\n types-[%s]\n table count-[%d]\n num of row-[%d]\n execute times-[%d]\n", + *cmd, *types, *tableCount, *numOfRow, *times) + } + + tableCnt := *tableCount + if tableCnt > maxTableCnt { + tableCnt = maxTableCnt + } + + b, err := newBench(taosDb) + panicIf("init connection ", err) + defer b.close() + + useCmd := fmt.Sprintf("use %s", benchmarkDb) + _, err = b.taos.Exec(useCmd) + panicIf(useCmd, err) + + switch *cmd { + case insertCmd: + b.insert(ctx, *types, tableCnt) + case batchInsertCmd: + b.batchInsert(ctx, *types, tableCnt, *numOfRow) + case queryCmd: + b.query(ctx, *types, *times) + case avgCmd: + b.average(ctx, *types, *times) + } +} + +type bench struct { + taos *sql.DB +} + +func newBench(dbUrl string) (*bench, error) { + taos, err := sql.Open("taosSql", dbUrl) + return &bench{taos: taos}, err +} + +func (b *bench) close() { + _ = b.taos.Close() +} + +func (b *bench) insert(ctx context.Context, types string, tableCnt int) { + table := stb + if types == jsonType { + table = jtb + } + begin := time.Now().UnixNano() / int64(time.Millisecond) + + for i := 0; i < tableCnt; i++ { + _, err := b.taos.ExecContext(ctx, + fmt.Sprintf( + "insert into %s_%d values(%d, true, -1, -2, -3, -4, 1, 2, 3, 4, 3.1415, 3.14159265358979, 'bnr_col_1', 'ncr_col_1')", + table, + i, + begin)) + panicIf("single insert", err) + } +} + +func (b *bench) batchInsert(ctx context.Context, types string, tableCnt, numOfRows int) { + table := stb + if types == jsonType { + table = jtb + } + begin := time.Now().UnixNano() / int64(time.Millisecond) + + for i := 0; i < tableCnt; i++ { + tableName := fmt.Sprintf("%s_%d", table, i) + batchSql := batchInsertSql(begin, tableName, numOfRows) + _, err := b.taos.ExecContext(ctx, batchSql) + panicIf("batch insert", err) + } +} + +func (b *bench) query(ctx context.Context, types string, times int) { + if types == normalType { + for i := 0; i < times; i++ { + rs, err := b.taos.QueryContext(ctx, queryStb) + panicIf("query normal", err) + readStbRow(rs) + } + return + } + + if types == jsonType { + for i := 0; i < times; i++ { + rs, err := b.taos.QueryContext(ctx, queryJtb) + panicIf("query json", err) + readJtbRow(rs) + } + } +} + +func (b *bench) average(ctx context.Context, types string, times int) { + query := avgStbSql + if types == jsonType { + query = avgJtbSql + } + + for i := 0; i < times; i++ { + rs, err := b.taos.QueryContext(ctx, query) + panicIf("average", err) + + for rs.Next() { + var avg float64 + err = rs.Scan(&avg) + panicIf("scan average data", err) + } + } +} + +func readStbRow(rs *sql.Rows) { + defer func() { _ = rs.Close() }() + + for rs.Next() { + var ( + ts time.Time + bl bool + i8 int8 + i16 int16 + i32 int32 + i64 int64 + u8 uint8 + u16 uint16 + u32 uint32 + u64 uint64 + f32 float32 + d64 float64 + bnr string + nchar string + t0 bool + t1 uint8 + t2 uint16 + t3 uint32 + t4 uint64 + t5 int8 + t6 int16 + t7 int32 + t8 int64 + t9 float32 + t10 float64 + t11 string + t12 string + ) + err := rs.Scan(&ts, &bl, &i8, &i16, &i32, &i64, &u8, &u16, &u32, &u64, &f32, &d64, &bnr, &nchar, &t0, &t1, &t2, + &t3, &t4, &t5, &t6, &t7, &t8, &t9, &t10, &t11, &t12) + panicIf("read row", err) + } +} + +func readJtbRow(rs *sql.Rows) { + defer func() { _ = rs.Close() }() + + for rs.Next() { + var ( + ts time.Time + bl bool + i8 int8 + i16 int16 + i32 int32 + i64 int64 + u8 uint8 + u16 uint16 + u32 uint32 + u64 uint64 + f32 float32 + d64 float64 + bnr string + nchar string + k0 string + k1 string + k2 string + k3 string + ) + err := rs.Scan(&ts, &bl, &i8, &i16, &i32, &i64, &u8, &u16, &u32, &u64, &f32, &d64, &bnr, &nchar, &k0, &k1, &k2, &k3) + panicIf("read row", err) + } +} + +func batchInsertSql(begin int64, table string, numOfRows int) string { + var buffer bytes.Buffer + buffer.WriteString(fmt.Sprintf("insert into %s values ", table)) + + for i := 0; i < numOfRows; i++ { + buffer.WriteString(fmt.Sprintf("(%d, %t, %d, %d, %d, %d, %d, %d, %d, %d, %.4f, %f, '%s', '%s')", + begin+int64(i), + rand.Intn(2) == 1, // bl + rand.Intn(256)-128, // i8 [-128, 127] + rand.Intn(65535)-32768, // i16 [-32768, 32767] + rand.Int31(), // i32 [-2^31, 2^31-1] + rand.Int63(), // i64 [-2^63, 2^63-1] + rand.Intn(256), // u8 + rand.Intn(65535), // u16 + rand.Uint32(), // u32 + rand.Uint64(), // u64 + rand.Float32(), // f32 + rand.Float64(), // d64 + randStr(20), // bnr + randStr(20), // nchr + )) + } + buffer.WriteString(";") + return buffer.String() +} + +const chars = "01234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + +func randStr(n int) string { + var buf bytes.Buffer + for i := 0; i < n; i++ { + r := rand.Intn(63) + buf.WriteString(chars[r : r+1]) + } + return buf.String() +} + +func panicIf(msg string, err error) { + if err != nil { + panic(fmt.Errorf("%s %v", msg, err)) + } +} diff --git a/benchmark/data/only_create_table_with_json_tag.json b/benchmark/data/only_create_table_with_json_tag.json new file mode 100644 index 0000000..997430a --- /dev/null +++ b/benchmark/data/only_create_table_with_json_tag.json @@ -0,0 +1,123 @@ +{ + "filetype": "insert", + "cfgdir": "/etc/taos", + "host": "127.0.0.1", + "port": 6030, + "user": "root", + "password": "taosdata", + "connection_pool_size": 8, + "thread_count": 16, + "create_table_thread_count": 16, + "confirm_parameter_prompt": "no", + "insert_interval": 0, + "interlace_rows": 0, + "num_of_records_per_req": 30000, + "prepared_rand": 10000, + "chinese": "no", + "databases": [ + { + "dbinfo": { + "name": "benchmark", + "drop": "no", + "replica": 1, + "precision": "ms", + "keep": 3650, + "minRows": 100, + "maxRows": 4096, + "comp": 2 + }, + "super_tables": [ + { + "name": "jtb", + "child_table_exists": "no", + "childtable_count": 10000, + "childtable_prefix": "jtb_", + "escape_character": "yes", + "auto_create_table": "no", + "batch_create_tbl_num": 10000, + "data_source": "rand", + "insert_mode": "taosc", + "non_stop_mode": "no", + "line_protocol": "line", + "insert_rows": 0, + "childtable_limit": 10, + "childtable_offset": 100, + "interlace_rows": 0, + "insert_interval": 0, + "partial_col_num": 0, + "disorder_ratio": 0, + "disorder_range": 1000, + "timestamp_step": 10, + "start_timestamp": "2020-10-01 00:00:00.000", + "sample_format": "csv", + "sample_file": "./sample.csv", + "use_sample_ts": "no", + "tags_file": "", + "columns": [ + { + "type": "BOOL", + "name": "bl" + }, + { + "type": "TINYINT", + "name": "i8" + }, + { + "type": "SMALLINT", + "name": "i16" + }, + { + "type": "INT", + "name": "i32" + }, + { + "type": "BIGINT", + "name": "i64" + }, + { + "type": "UTINYINT", + "name": "u8" + }, + { + "type": "USMALLINT", + "name": "u16" + }, + { + "type": "UINT", + "name": "u32" + }, + { + "type": "UBIGINT", + "name": "u64" + }, + { + "type": "FLOAT", + "name": "f32" + }, + { + "type": "DOUBLE", + "name": "d64" + }, + { + "type": "VARCHAR", + "name": "bnr", + "len": 20 + }, + { + "type": "NCHAR", + "name": "nchr", + "len": 20 + } + ], + "tags": [ + { + "type": "JSON", + "len": 8, + "count": 4 + } + ] + } + ] + } + ] +} \ No newline at end of file diff --git a/benchmark/data/only_create_table_with_normal_tag.json b/benchmark/data/only_create_table_with_normal_tag.json new file mode 100644 index 0000000..3f04415 --- /dev/null +++ b/benchmark/data/only_create_table_with_normal_tag.json @@ -0,0 +1,158 @@ +{ + "filetype": "insert", + "cfgdir": "/etc/taos", + "host": "127.0.0.1", + "port": 6030, + "user": "root", + "password": "taosdata", + "thread_count": 16, + "create_table_thread_count": 7, + "confirm_parameter_prompt": "no", + "insert_interval": 0, + "interlace_rows": 0, + "num_of_records_per_req": 30000, + "prepared_rand": 10000, + "chinese": "no", + "databases": [ + { + "dbinfo": { + "name": "benchmark", + "drop": "no", + "replica": 1, + "precision": "ms", + "keep": 3650, + "minRows": 100, + "maxRows": 4096, + "comp": 2 + }, + "super_tables": [ + { + "name": "stb", + "child_table_exists": "no", + "childtable_count": 10000, + "childtable_prefix": "stb_", + "escape_character": "yes", + "auto_create_table": "no", + "batch_create_tbl_num": 10000, + "data_source": "rand", + "insert_mode": "taosc", + "non_stop_mode": "no", + "line_protocol": "line", + "insert_rows": 0, + "childtable_limit": 10, + "childtable_offset": 100, + "interlace_rows": 0, + "insert_interval": 0, + "partial_col_num": 0, + "disorder_ratio": 0, + "disorder_range": 1000, + "timestamp_step": 10, + "start_timestamp": "2020-10-01 00:00:00.000", + "sample_format": "csv", + "sample_file": "./sample.csv", + "use_sample_ts": "no", + "tags_file": "", + "columns": [ + { + "type": "BOOL", + "name": "bl" + }, + { + "type": "TINYINT", + "name": "i8" + }, + { + "type": "SMALLINT", + "name": "i16" + }, + { + "type": "INT", + "name": "i32" + }, + { + "type": "BIGINT", + "name": "i64" + }, + { + "type": "UTINYINT", + "name": "u8" + }, + { + "type": "USMALLINT", + "name": "u16" + }, + { + "type": "UINT", + "name": "u32" + }, + { + "type": "UBIGINT", + "name": "u64" + }, + { + "type": "FLOAT", + "name": "f32" + }, + { + "type": "DOUBLE", + "name": "d64" + }, + { + "type": "VARCHAR", + "name": "bnr", + "len": 20 + }, + { + "type": "NCHAR", + "name": "nchr", + "len": 20 + } + ], + "tags": [ + { + "type": "BOOL" + }, + { + "type": "UTINYINT" + }, + { + "type": "USMALLINT" + }, + { + "type": "UINT" + }, + { + "type": "UBIGINT" + }, + { + "type": "TINYINT" + }, + { + "type": "SMALLINT" + }, + { + "type": "INT" + }, + { + "type": "BIGINT" + }, + { + "type": "FLOAT" + }, + { + "type": "DOUBLE" + }, + { + "type": "VARCHAR", + "len": 20 + }, + { + "type": "NCHAR", + "len": 20 + } + ] + } + ] + } + ] +} \ No newline at end of file diff --git a/benchmark/run_bench.sh b/benchmark/run_bench.sh new file mode 100644 index 0000000..c04fa1a --- /dev/null +++ b/benchmark/run_bench.sh @@ -0,0 +1,87 @@ +#!/bin/sh +BENCHMARK_TIMES=$1 +BATCH_TABLES=$2 +BATCH_ROWS=$3 + +REPORT_NAME="golang_run_${BENCHMARK_TIMES}" +RESULT_FOLDER="result" +INSERT_TABLE_NUM=10000 + +echo "====== starting ..." +if [ ! -d ${RESULT_FOLDER} ] +then + mkdir ${RESULT_FOLDER} +fi +echo "BENCHMARK_TIMES:${BENCHMARK_TIMES}" + +clear remaining result report +rm ./${RESULT_FOLDER}/*.md + +echo "====== preparing ..." +echo "=== build benchmark code " +rm -f benchmark +go build -o benchmark benchmark.go + +echo "=== create database for benchmark " +taos -s 'create database if not exists benchmark' + +echo "===== step 1 create tables ..." +taos -s 'drop stable if exists benchmark.stb' +taos -s 'drop stable if exists benchmark.jtb' +taosBenchmark -f ./data/only_create_table_with_normal_tag.json +taosBenchmark -f ./data/only_create_table_with_json_tag.json + +echo "===== step 2 insert data ..." +hyperfine -r ${BENCHMARK_TIMES} -L types normal,json -L tables ${INSERT_TABLE_NUM} \ + './benchmark -s insert -t {types} -b {tables}' \ + --time-unit millisecond \ + --show-output \ + --export-markdown ${RESULT_FOLDER}/${REPORT_NAME}_insert.md \ + --command-name insert_{types}_${INSERT_TABLE_NUM}_tables_${BENCHMARK_TIMES}_times + +echo "===== step 3 clean data and create tables ..." +taos -s 'drop stable if exists benchmark.stb' +taos -s 'drop stable if exists benchmark.jtb' +taosBenchmark -f ./data/only_create_table_with_normal_tag.json +taosBenchmark -f ./data/only_create_table_with_json_tag.json + +echo "===== step 4 insert data with batch ..." + hyperfine -r ${BENCHMARK_TIMES} -L rows ${BATCH_ROWS} -L tables ${BATCH_TABLES} \ + -L types normal,json \ + './benchmark -s batch -t {types} -r {rows} -b {tables}' \ + --time-unit millisecond \ + --show-output \ + --export-markdown ${RESULT_FOLDER}/${REPORT_NAME}_bath.md \ + --command-name batch_{types}_${BATCH_TABLES}_tables_${BENCHMARK_TIMES}_times + +echo "===== step 5 query..." +hyperfine -r ${BENCHMARK_TIMES} -L types normal,json \ + './benchmark -s query -t {types}' \ + --time-unit millisecond \ + --show-output \ + --export-markdown ${RESULT_FOLDER}/${REPORT_NAME}_query.md \ + --command-name query_{types}_${BENCHMARK_TIMES}_times + +echo "===== step 6 avg ..." +hyperfine -r ${BENCHMARK_TIMES} -L types normal,json \ + './benchmark -s avg -t {types}' \ + --time-unit millisecond \ + --show-output \ + --export-markdown ${RESULT_FOLDER}/${REPORT_NAME}_avg.md \ + --command-name avg_{types}_${BENCHMARK_TIMES}_times + + +echo "| Command | Mean [ms] | Min [ms] | Max [ms] | Relative |">>./${RESULT_FOLDER}/${REPORT_NAME}.md +echo "|:---|---:|---:|---:|---:|">>./${RESULT_FOLDER}/${REPORT_NAME}.md +ls ./${RESULT_FOLDER}/*.md| +while read filename; +do + sed -n '3,4p' ${filename}>>${RESULT_FOLDER}/${REPORT_NAME}.md +done + +echo "=== clean database and binary file ... " +rm -f benchmark +taos -s 'drop database benchmark' + +echo "=== benchmark done ... " +echo "=== result file:${RESULT_FOLDER}/${REPORT_NAME}.md " diff --git a/common/change.go b/common/change.go index 1e32d4c..d287162 100644 --- a/common/change.go +++ b/common/change.go @@ -1,6 +1,9 @@ package common -import "time" +import ( + "fmt" + "time" +) func TimestampConvertToTime(timestamp int64, precision int) time.Time { switch precision { @@ -11,7 +14,8 @@ func TimestampConvertToTime(timestamp int64, precision int) time.Time { case PrecisionNanoSecond: // nano-second return time.Unix(0, timestamp) default: - panic("unknown precision") + s := fmt.Sprintln("unknown precision", precision, "timestamp", timestamp) + panic(s) } } @@ -24,6 +28,7 @@ func TimeToTimestamp(t time.Time, precision int) (timestamp int64) { case PrecisionNanoSecond: return t.UnixNano() default: - panic("unknown precision") + s := fmt.Sprintln("unknown precision", precision, "time", t) + panic(s) } } diff --git a/common/const.go b/common/const.go index 7e97acf..42ea950 100644 --- a/common/const.go +++ b/common/const.go @@ -1,5 +1,7 @@ package common +import "unsafe" + const ( MaxTaosSqlLen = 1048576 DefaultUser = "root" @@ -18,7 +20,7 @@ const ( TSDB_OPTION_TIMEZONE TSDB_OPTION_CONFIGDIR TSDB_OPTION_SHELL_ACTIVITY_TIMER - TSDB_MAX_OPTIONS + TSDB_OPTION_USE_ADAPTER ) const ( @@ -106,4 +108,36 @@ const ( TMQ_RES_INVALID = -1 TMQ_RES_DATA = 1 TMQ_RES_TABLE_META = 2 + TMQ_RES_METADATA = 3 ) + +var TypeLengthMap = map[int]int{ + TSDB_DATA_TYPE_NULL: 1, + TSDB_DATA_TYPE_BOOL: 1, + TSDB_DATA_TYPE_TINYINT: 1, + TSDB_DATA_TYPE_SMALLINT: 2, + TSDB_DATA_TYPE_INT: 4, + TSDB_DATA_TYPE_BIGINT: 8, + TSDB_DATA_TYPE_FLOAT: 4, + TSDB_DATA_TYPE_DOUBLE: 8, + TSDB_DATA_TYPE_TIMESTAMP: 8, + TSDB_DATA_TYPE_UTINYINT: 1, + TSDB_DATA_TYPE_USMALLINT: 2, + TSDB_DATA_TYPE_UINT: 4, + TSDB_DATA_TYPE_UBIGINT: 8, +} + +const ( + Int8Size = unsafe.Sizeof(int8(0)) + Int16Size = unsafe.Sizeof(int16(0)) + Int32Size = unsafe.Sizeof(int32(0)) + Int64Size = unsafe.Sizeof(int64(0)) + UInt8Size = unsafe.Sizeof(uint8(0)) + UInt16Size = unsafe.Sizeof(uint16(0)) + UInt32Size = unsafe.Sizeof(uint32(0)) + UInt64Size = unsafe.Sizeof(uint64(0)) + Float32Size = unsafe.Sizeof(float32(0)) + Float64Size = unsafe.Sizeof(float64(0)) +) + +const ReqIDKey = "taos_req_id" diff --git a/common/param/param.go b/common/param/param.go index 8e2abe3..a9ec02d 100644 --- a/common/param/param.go +++ b/common/param/param.go @@ -32,7 +32,6 @@ func (p *Param) SetNull(offset int) { return } p.value[offset] = nil - return } func (p *Param) SetTinyint(offset int, value int) { @@ -286,3 +285,12 @@ func (p *Param) AddJson(value []byte) *Param { func (p *Param) GetValues() []driver.Value { return p.value } + +func (p *Param) AddValue(value interface{}) *Param { + if p.offset >= p.size { + return p + } + p.value[p.offset] = value + p.offset += 1 + return p +} diff --git a/common/parser/block.go b/common/parser/block.go new file mode 100644 index 0000000..f573804 --- /dev/null +++ b/common/parser/block.go @@ -0,0 +1,385 @@ +package parser + +import ( + "database/sql/driver" + "math" + "unsafe" + + "github.com/taosdata/driver-go/v3/common" +) + +const ( + Int8Size = common.Int8Size + Int16Size = common.Int16Size + Int32Size = common.Int32Size + Int64Size = common.Int64Size + UInt8Size = common.UInt8Size + UInt16Size = common.UInt16Size + UInt32Size = common.UInt32Size + UInt64Size = common.UInt64Size + Float32Size = common.Float32Size + Float64Size = common.Float64Size +) + +const ( + ColInfoSize = Int8Size + Int32Size + RawBlockVersionOffset = 0 + RawBlockLengthOffset = RawBlockVersionOffset + Int32Size + NumOfRowsOffset = RawBlockLengthOffset + Int32Size + NumOfColsOffset = NumOfRowsOffset + Int32Size + HasColumnSegmentOffset = NumOfColsOffset + Int32Size + GroupIDOffset = HasColumnSegmentOffset + Int32Size + ColInfoOffset = GroupIDOffset + UInt64Size +) + +func RawBlockGetVersion(rawBlock unsafe.Pointer) int32 { + return *((*int32)(unsafe.Pointer(uintptr(rawBlock) + RawBlockVersionOffset))) +} + +func RawBlockGetLength(rawBlock unsafe.Pointer) int32 { + return *((*int32)(unsafe.Pointer(uintptr(rawBlock) + RawBlockLengthOffset))) +} + +func RawBlockGetNumOfRows(rawBlock unsafe.Pointer) int32 { + return *((*int32)(unsafe.Pointer(uintptr(rawBlock) + NumOfRowsOffset))) +} + +func RawBlockGetNumOfCols(rawBlock unsafe.Pointer) int32 { + return *((*int32)(unsafe.Pointer(uintptr(rawBlock) + NumOfColsOffset))) +} + +func RawBlockGetHasColumnSegment(rawBlock unsafe.Pointer) int32 { + return *((*int32)(unsafe.Pointer(uintptr(rawBlock) + HasColumnSegmentOffset))) +} + +func RawBlockGetGroupID(rawBlock unsafe.Pointer) uint64 { + return *((*uint64)(unsafe.Pointer(uintptr(rawBlock) + GroupIDOffset))) +} + +type RawBlockColInfo struct { + ColType int8 + Bytes int32 +} + +func RawBlockGetColInfo(rawBlock unsafe.Pointer, infos []RawBlockColInfo) { + for i := 0; i < len(infos); i++ { + offset := uintptr(rawBlock) + ColInfoOffset + ColInfoSize*uintptr(i) + infos[i].ColType = *((*int8)(unsafe.Pointer(offset))) + infos[i].Bytes = *((*int32)(unsafe.Pointer(offset + Int8Size))) + } +} + +func RawBlockGetColumnLengthOffset(colCount int) uintptr { + return ColInfoOffset + uintptr(colCount)*ColInfoSize +} + +func RawBlockGetColDataOffset(colCount int) uintptr { + return ColInfoOffset + uintptr(colCount)*ColInfoSize + uintptr(colCount)*Int32Size +} + +type FormatTimeFunc func(ts int64, precision int) driver.Value + +func IsVarDataType(colType uint8) bool { + return colType == common.TSDB_DATA_TYPE_BINARY || colType == common.TSDB_DATA_TYPE_NCHAR || colType == common.TSDB_DATA_TYPE_JSON +} + +func BitmapLen(n int) int { + return ((n) + ((1 << 3) - 1)) >> 3 +} + +func BitPos(n int) int { + return n & ((1 << 3) - 1) +} + +func CharOffset(n int) int { + return n >> 3 +} + +func BMIsNull(c byte, n int) bool { + return c&(1<<(7-BitPos(n))) == (1 << (7 - BitPos(n))) +} + +type rawConvertFunc func(pStart uintptr, row int, arg ...interface{}) driver.Value + +type rawConvertVarDataFunc func(pHeader, pStart uintptr, row int) driver.Value + +var rawConvertFuncMap = map[uint8]rawConvertFunc{ + uint8(common.TSDB_DATA_TYPE_BOOL): rawConvertBool, + uint8(common.TSDB_DATA_TYPE_TINYINT): rawConvertTinyint, + uint8(common.TSDB_DATA_TYPE_SMALLINT): rawConvertSmallint, + uint8(common.TSDB_DATA_TYPE_INT): rawConvertInt, + uint8(common.TSDB_DATA_TYPE_BIGINT): rawConvertBigint, + uint8(common.TSDB_DATA_TYPE_UTINYINT): rawConvertUTinyint, + uint8(common.TSDB_DATA_TYPE_USMALLINT): rawConvertUSmallint, + uint8(common.TSDB_DATA_TYPE_UINT): rawConvertUInt, + uint8(common.TSDB_DATA_TYPE_UBIGINT): rawConvertUBigint, + uint8(common.TSDB_DATA_TYPE_FLOAT): rawConvertFloat, + uint8(common.TSDB_DATA_TYPE_DOUBLE): rawConvertDouble, + uint8(common.TSDB_DATA_TYPE_TIMESTAMP): rawConvertTime, +} + +var rawConvertVarDataMap = map[uint8]rawConvertVarDataFunc{ + uint8(common.TSDB_DATA_TYPE_BINARY): rawConvertBinary, + uint8(common.TSDB_DATA_TYPE_NCHAR): rawConvertNchar, + uint8(common.TSDB_DATA_TYPE_JSON): rawConvertJson, +} + +func ItemIsNull(pHeader uintptr, row int) bool { + offset := CharOffset(row) + c := *((*byte)(unsafe.Pointer(pHeader + uintptr(offset)))) + return BMIsNull(c, row) +} + +func rawConvertBool(pStart uintptr, row int, _ ...interface{}) driver.Value { + if (*((*byte)(unsafe.Pointer(pStart + uintptr(row)*1)))) != 0 { + return true + } else { + return false + } +} + +func rawConvertTinyint(pStart uintptr, row int, _ ...interface{}) driver.Value { + return *((*int8)(unsafe.Pointer(pStart + uintptr(row)*Int8Size))) +} + +func rawConvertSmallint(pStart uintptr, row int, _ ...interface{}) driver.Value { + return *((*int16)(unsafe.Pointer(pStart + uintptr(row)*Int16Size))) +} + +func rawConvertInt(pStart uintptr, row int, _ ...interface{}) driver.Value { + return *((*int32)(unsafe.Pointer(pStart + uintptr(row)*Int32Size))) +} + +func rawConvertBigint(pStart uintptr, row int, _ ...interface{}) driver.Value { + return *((*int64)(unsafe.Pointer(pStart + uintptr(row)*Int64Size))) +} + +func rawConvertUTinyint(pStart uintptr, row int, _ ...interface{}) driver.Value { + return *((*uint8)(unsafe.Pointer(pStart + uintptr(row)*UInt8Size))) +} + +func rawConvertUSmallint(pStart uintptr, row int, _ ...interface{}) driver.Value { + return *((*uint16)(unsafe.Pointer(pStart + uintptr(row)*UInt16Size))) +} + +func rawConvertUInt(pStart uintptr, row int, _ ...interface{}) driver.Value { + return *((*uint32)(unsafe.Pointer(pStart + uintptr(row)*UInt32Size))) +} + +func rawConvertUBigint(pStart uintptr, row int, _ ...interface{}) driver.Value { + return *((*uint64)(unsafe.Pointer(pStart + uintptr(row)*UInt64Size))) +} + +func rawConvertFloat(pStart uintptr, row int, _ ...interface{}) driver.Value { + return math.Float32frombits(*((*uint32)(unsafe.Pointer(pStart + uintptr(row)*Float32Size)))) +} + +func rawConvertDouble(pStart uintptr, row int, _ ...interface{}) driver.Value { + return math.Float64frombits(*((*uint64)(unsafe.Pointer(pStart + uintptr(row)*Float64Size)))) +} + +func rawConvertTime(pStart uintptr, row int, arg ...interface{}) driver.Value { + if len(arg) == 1 { + return common.TimestampConvertToTime(*((*int64)(unsafe.Pointer(pStart + uintptr(row)*Int64Size))), arg[0].(int)) + } else if len(arg) == 2 { + return arg[1].(FormatTimeFunc)(*((*int64)(unsafe.Pointer(pStart + uintptr(row)*Int64Size))), arg[0].(int)) + } else { + panic("convertTime error") + } +} + +func rawConvertBinary(pHeader, pStart uintptr, row int) driver.Value { + offset := *((*int32)(unsafe.Pointer(pHeader + uintptr(row*4)))) + if offset == -1 { + return nil + } + currentRow := unsafe.Pointer(pStart + uintptr(offset)) + clen := *((*int16)(currentRow)) + currentRow = unsafe.Pointer(uintptr(currentRow) + 2) + + binaryVal := make([]byte, clen) + + for index := int16(0); index < clen; index++ { + binaryVal[index] = *((*byte)(unsafe.Pointer(uintptr(currentRow) + uintptr(index)))) + } + return string(binaryVal[:]) +} + +func rawConvertNchar(pHeader, pStart uintptr, row int) driver.Value { + offset := *((*int32)(unsafe.Pointer(pHeader + uintptr(row*4)))) + if offset == -1 { + return nil + } + currentRow := unsafe.Pointer(pStart + uintptr(offset)) + clen := *((*int16)(currentRow)) / 4 + currentRow = unsafe.Pointer(uintptr(currentRow) + 2) + + binaryVal := make([]rune, clen) + + for index := int16(0); index < clen; index++ { + binaryVal[index] = *((*rune)(unsafe.Pointer(uintptr(currentRow) + uintptr(index*4)))) + } + return string(binaryVal) +} + +func rawConvertJson(pHeader, pStart uintptr, row int) driver.Value { + offset := *((*int32)(unsafe.Pointer(pHeader + uintptr(row*4)))) + if offset == -1 { + return nil + } + currentRow := unsafe.Pointer(pStart + uintptr(offset)) + clen := *((*int16)(currentRow)) + currentRow = unsafe.Pointer(uintptr(currentRow) + 2) + + binaryVal := make([]byte, clen) + + for index := int16(0); index < clen; index++ { + binaryVal[index] = *((*byte)(unsafe.Pointer(uintptr(currentRow) + uintptr(index)))) + } + return binaryVal[:] +} + +// ReadBlock in-place +func ReadBlock(block unsafe.Pointer, blockSize int, colTypes []uint8, precision int) [][]driver.Value { + r := make([][]driver.Value, blockSize) + colCount := len(colTypes) + nullBitMapOffset := uintptr(BitmapLen(blockSize)) + lengthOffset := RawBlockGetColumnLengthOffset(colCount) + pHeader := uintptr(block) + RawBlockGetColDataOffset(colCount) + var pStart uintptr + for column := 0; column < colCount; column++ { + colLength := *((*int32)(unsafe.Pointer(uintptr(block) + lengthOffset + uintptr(column)*Int32Size))) + if IsVarDataType(colTypes[column]) { + convertF := rawConvertVarDataMap[colTypes[column]] + pStart = pHeader + Int32Size*uintptr(blockSize) + for row := 0; row < blockSize; row++ { + if column == 0 { + r[row] = make([]driver.Value, colCount) + } + r[row][column] = convertF(pHeader, pStart, row) + } + } else { + convertF := rawConvertFuncMap[colTypes[column]] + pStart = pHeader + nullBitMapOffset + for row := 0; row < blockSize; row++ { + if column == 0 { + r[row] = make([]driver.Value, colCount) + } + if ItemIsNull(pHeader, row) { + r[row][column] = nil + } else { + r[row][column] = convertF(pStart, row, precision) + } + } + } + pHeader = pStart + uintptr(colLength) + } + return r +} + +func ReadRow(dest []driver.Value, block unsafe.Pointer, blockSize int, row int, colTypes []uint8, precision int) { + colCount := len(colTypes) + nullBitMapOffset := uintptr(BitmapLen(blockSize)) + lengthOffset := RawBlockGetColumnLengthOffset(colCount) + pHeader := uintptr(block) + RawBlockGetColDataOffset(colCount) + var pStart uintptr + for column := 0; column < colCount; column++ { + colLength := *((*int32)(unsafe.Pointer(uintptr(block) + lengthOffset + uintptr(column)*Int32Size))) + if IsVarDataType(colTypes[column]) { + convertF := rawConvertVarDataMap[colTypes[column]] + pStart = pHeader + Int32Size*uintptr(blockSize) + dest[column] = convertF(pHeader, pStart, row) + } else { + convertF := rawConvertFuncMap[colTypes[column]] + pStart = pHeader + nullBitMapOffset + if ItemIsNull(pHeader, row) { + dest[column] = nil + } else { + dest[column] = convertF(pStart, row, precision) + } + } + pHeader = pStart + uintptr(colLength) + } +} + +func ReadBlockWithTimeFormat(block unsafe.Pointer, blockSize int, colTypes []uint8, precision int, formatFunc FormatTimeFunc) [][]driver.Value { + r := make([][]driver.Value, blockSize) + colCount := len(colTypes) + nullBitMapOffset := uintptr(BitmapLen(blockSize)) + lengthOffset := RawBlockGetColumnLengthOffset(colCount) + pHeader := uintptr(block) + RawBlockGetColDataOffset(colCount) + var pStart uintptr + for column := 0; column < colCount; column++ { + colLength := *((*int32)(unsafe.Pointer(uintptr(block) + lengthOffset + uintptr(column)*Int32Size))) + if IsVarDataType(colTypes[column]) { + convertF := rawConvertVarDataMap[colTypes[column]] + pStart = pHeader + uintptr(4*blockSize) + for row := 0; row < blockSize; row++ { + if column == 0 { + r[row] = make([]driver.Value, colCount) + } + r[row][column] = convertF(pHeader, pStart, row) + } + } else { + convertF := rawConvertFuncMap[colTypes[column]] + pStart = pHeader + nullBitMapOffset + for row := 0; row < blockSize; row++ { + if column == 0 { + r[row] = make([]driver.Value, colCount) + } + if ItemIsNull(pHeader, row) { + r[row][column] = nil + } else { + r[row][column] = convertF(pStart, row, precision, formatFunc) + } + } + } + pHeader = pStart + uintptr(colLength) + } + return r +} + +func ItemRawBlock(colType uint8, pHeader, pStart uintptr, row int, precision int, timeFormat FormatTimeFunc) driver.Value { + if IsVarDataType(colType) { + switch colType { + case uint8(common.TSDB_DATA_TYPE_BINARY): + return rawConvertBinary(pHeader, pStart, row) + case uint8(common.TSDB_DATA_TYPE_NCHAR): + return rawConvertNchar(pHeader, pStart, row) + case uint8(common.TSDB_DATA_TYPE_JSON): + return rawConvertJson(pHeader, pStart, row) + } + } else { + if ItemIsNull(pHeader, row) { + return nil + } else { + switch colType { + case uint8(common.TSDB_DATA_TYPE_BOOL): + return rawConvertBool(pStart, row) + case uint8(common.TSDB_DATA_TYPE_TINYINT): + return rawConvertTinyint(pStart, row) + case uint8(common.TSDB_DATA_TYPE_SMALLINT): + return rawConvertSmallint(pStart, row) + case uint8(common.TSDB_DATA_TYPE_INT): + return rawConvertInt(pStart, row) + case uint8(common.TSDB_DATA_TYPE_BIGINT): + return rawConvertBigint(pStart, row) + case uint8(common.TSDB_DATA_TYPE_UTINYINT): + return rawConvertUTinyint(pStart, row) + case uint8(common.TSDB_DATA_TYPE_USMALLINT): + return rawConvertUSmallint(pStart, row) + case uint8(common.TSDB_DATA_TYPE_UINT): + return rawConvertUInt(pStart, row) + case uint8(common.TSDB_DATA_TYPE_UBIGINT): + return rawConvertUBigint(pStart, row) + case uint8(common.TSDB_DATA_TYPE_FLOAT): + return rawConvertFloat(pStart, row) + case uint8(common.TSDB_DATA_TYPE_DOUBLE): + return rawConvertDouble(pStart, row) + case uint8(common.TSDB_DATA_TYPE_TIMESTAMP): + return rawConvertTime(pStart, row, precision, timeFormat) + } + } + } + return nil +} diff --git a/common/parser/block_test.go b/common/parser/block_test.go new file mode 100644 index 0000000..bd327de --- /dev/null +++ b/common/parser/block_test.go @@ -0,0 +1,773 @@ +package parser + +import ( + "database/sql/driver" + "fmt" + "testing" + "time" + "unsafe" + + "github.com/stretchr/testify/assert" + "github.com/taosdata/driver-go/v3/common" + "github.com/taosdata/driver-go/v3/errors" + "github.com/taosdata/driver-go/v3/wrapper" +) + +func TestReadBlock(t *testing.T) { + conn, err := wrapper.TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + + defer wrapper.TaosClose(conn) + defer func() { + res := wrapper.TaosQuery(conn, "drop database if exists test_block_raw_parser") + code := wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(res) + }() + res := wrapper.TaosQuery(conn, "create database if not exists test_block_raw_parser") + code := wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(res) + res = wrapper.TaosQuery(conn, "drop table if exists test_block_raw_parser.all_type2") + code = wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(res) + res = wrapper.TaosQuery(conn, "create table if not exists test_block_raw_parser.all_type2 (ts timestamp,"+ + "c1 bool,"+ + "c2 tinyint,"+ + "c3 smallint,"+ + "c4 int,"+ + "c5 bigint,"+ + "c6 tinyint unsigned,"+ + "c7 smallint unsigned,"+ + "c8 int unsigned,"+ + "c9 bigint unsigned,"+ + "c10 float,"+ + "c11 double,"+ + "c12 binary(20),"+ + "c13 nchar(20)"+ + ")") + code = wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(res) + now := time.Now() + after1s := now.Add(time.Second) + sql := fmt.Sprintf("insert into test_block_raw_parser.all_type2 values('%s',1,1,1,1,1,1,1,1,1,1,1,'test_binary','test_nchar')('%s',null,null,null,null,null,null,null,null,null,null,null,null,null)", now.Format(time.RFC3339Nano), after1s.Format(time.RFC3339Nano)) + res = wrapper.TaosQuery(conn, sql) + code = wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(res) + + sql = "select * from test_block_raw_parser.all_type2" + res = wrapper.TaosQuery(conn, sql) + code = wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + fileCount := wrapper.TaosNumFields(res) + rh, err := wrapper.ReadColumn(res, fileCount) + if err != nil { + t.Error(err) + return + } + precision := wrapper.TaosResultPrecision(res) + pHeaderList := make([]uintptr, fileCount) + pStartList := make([]uintptr, fileCount) + var data [][]driver.Value + for { + blockSize, errCode, block := wrapper.TaosFetchRawBlock(res) + if errCode != int(errors.SUCCESS) { + errStr := wrapper.TaosErrorStr(res) + err := errors.NewError(code, errStr) + t.Error(err) + wrapper.TaosFreeResult(res) + return + } + if blockSize == 0 { + break + } + nullBitMapOffset := uintptr(BitmapLen(blockSize)) + lengthOffset := RawBlockGetColumnLengthOffset(fileCount) + tmpPHeader := uintptr(block) + RawBlockGetColDataOffset(fileCount) + var tmpPStart uintptr + for column := 0; column < fileCount; column++ { + colLength := *((*int32)(unsafe.Pointer(uintptr(block) + lengthOffset + uintptr(column)*Int32Size))) + if IsVarDataType(rh.ColTypes[column]) { + pHeaderList[column] = tmpPHeader + tmpPStart = tmpPHeader + Int32Size*uintptr(blockSize) + pStartList[column] = tmpPStart + } else { + pHeaderList[column] = tmpPHeader + tmpPStart = tmpPHeader + nullBitMapOffset + pStartList[column] = tmpPStart + } + tmpPHeader = tmpPStart + uintptr(colLength) + } + for row := 0; row < blockSize; row++ { + rowV := make([]driver.Value, fileCount) + for column := 0; column < fileCount; column++ { + v := ItemRawBlock(rh.ColTypes[column], pHeaderList[column], pStartList[column], row, precision, func(ts int64, precision int) driver.Value { + return common.TimestampConvertToTime(ts, precision) + }) + rowV[column] = v + } + data = append(data, rowV) + } + } + wrapper.TaosFreeResult(res) + assert.Equal(t, 2, len(data)) + row1 := data[0] + assert.Equal(t, now.UnixNano()/1e6, row1[0].(time.Time).UnixNano()/1e6) + assert.Equal(t, true, row1[1].(bool)) + assert.Equal(t, int8(1), row1[2].(int8)) + assert.Equal(t, int16(1), row1[3].(int16)) + assert.Equal(t, int32(1), row1[4].(int32)) + assert.Equal(t, int64(1), row1[5].(int64)) + assert.Equal(t, uint8(1), row1[6].(uint8)) + assert.Equal(t, uint16(1), row1[7].(uint16)) + assert.Equal(t, uint32(1), row1[8].(uint32)) + assert.Equal(t, uint64(1), row1[9].(uint64)) + assert.Equal(t, float32(1), row1[10].(float32)) + assert.Equal(t, float64(1), row1[11].(float64)) + assert.Equal(t, "test_binary", row1[12].(string)) + assert.Equal(t, "test_nchar", row1[13].(string)) + row2 := data[1] + assert.Equal(t, after1s.UnixNano()/1e6, row2[0].(time.Time).UnixNano()/1e6) + for i := 1; i < 14; i++ { + assert.Nil(t, row2[i]) + } +} + +func TestBlockTag(t *testing.T) { + conn, err := wrapper.TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + + defer wrapper.TaosClose(conn) + defer func() { + res := wrapper.TaosQuery(conn, "drop database if exists test_block_abc1") + code := wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(res) + }() + res := wrapper.TaosQuery(conn, "create database if not exists test_block_abc1") + code := wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(res) + + res = wrapper.TaosQuery(conn, "use test_block_abc1") + code = wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(res) + + res = wrapper.TaosQuery(conn, "create table if not exists meters(ts timestamp, v int) tags(location varchar(16))") + code = wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(res) + + res = wrapper.TaosQuery(conn, "create table if not exists tb1 using meters tags('abcd')") + code = wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(res) + + sql := "select distinct tbname,location from meters;" + res = wrapper.TaosQuery(conn, sql) + code = wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + fileCount := wrapper.TaosNumFields(res) + rh, err := wrapper.ReadColumn(res, fileCount) + if err != nil { + t.Error(err) + return + } + precision := wrapper.TaosResultPrecision(res) + pHeaderList := make([]uintptr, fileCount) + pStartList := make([]uintptr, fileCount) + var data [][]driver.Value + for { + blockSize, errCode, block := wrapper.TaosFetchRawBlock(res) + if errCode != int(errors.SUCCESS) { + errStr := wrapper.TaosErrorStr(res) + err := errors.NewError(code, errStr) + t.Error(err) + wrapper.TaosFreeResult(res) + return + } + if blockSize == 0 { + break + } + nullBitMapOffset := uintptr(BitmapLen(blockSize)) + lengthOffset := RawBlockGetColumnLengthOffset(fileCount) + tmpPHeader := uintptr(block) + RawBlockGetColDataOffset(fileCount) // length i32, group u64 + var tmpPStart uintptr + for column := 0; column < fileCount; column++ { + colLength := *((*int32)(unsafe.Pointer(uintptr(block) + lengthOffset + uintptr(column)*Int32Size))) + if IsVarDataType(rh.ColTypes[column]) { + pHeaderList[column] = tmpPHeader + tmpPStart = tmpPHeader + Int32Size*uintptr(blockSize) + pStartList[column] = tmpPStart + } else { + pHeaderList[column] = tmpPHeader + tmpPStart = tmpPHeader + nullBitMapOffset + pStartList[column] = tmpPStart + } + tmpPHeader = tmpPStart + uintptr(colLength) + } + for row := 0; row < blockSize; row++ { + rowV := make([]driver.Value, fileCount) + for column := 0; column < fileCount; column++ { + v := ItemRawBlock(rh.ColTypes[column], pHeaderList[column], pStartList[column], row, precision, func(ts int64, precision int) driver.Value { + return common.TimestampConvertToTime(ts, precision) + }) + rowV[column] = v + } + data = append(data, rowV) + } + } + wrapper.TaosFreeResult(res) + t.Log(data) + t.Log(len(data[0][1].(string))) +} + +func TestReadRow(t *testing.T) { + conn, err := wrapper.TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + + defer wrapper.TaosClose(conn) + res := wrapper.TaosQuery(conn, "drop database if exists test_read_row") + code := wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(res) + defer func() { + res = wrapper.TaosQuery(conn, "drop database if exists test_read_row") + code = wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(res) + }() + res = wrapper.TaosQuery(conn, "create database test_read_row") + code = wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(res) + + res = wrapper.TaosQuery(conn, "create table if not exists test_read_row.all_type (ts timestamp,"+ + "c1 bool,"+ + "c2 tinyint,"+ + "c3 smallint,"+ + "c4 int,"+ + "c5 bigint,"+ + "c6 tinyint unsigned,"+ + "c7 smallint unsigned,"+ + "c8 int unsigned,"+ + "c9 bigint unsigned,"+ + "c10 float,"+ + "c11 double,"+ + "c12 binary(20),"+ + "c13 nchar(20)"+ + ") tags (info json)") + code = wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(res) + now := time.Now() + after1s := now.Add(time.Second) + sql := fmt.Sprintf("insert into test_read_row.t0 using test_read_row.all_type tags('{\"a\":1}') values('%s',1,1,1,1,1,1,1,1,1,1,1,'test_binary','test_nchar')('%s',null,null,null,null,null,null,null,null,null,null,null,null,null)", now.Format(time.RFC3339Nano), after1s.Format(time.RFC3339Nano)) + res = wrapper.TaosQuery(conn, sql) + code = wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(res) + + sql = "select * from test_read_row.all_type" + res = wrapper.TaosQuery(conn, sql) + code = wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + fileCount := wrapper.TaosNumFields(res) + rh, err := wrapper.ReadColumn(res, fileCount) + if err != nil { + t.Error(err) + return + } + precision := wrapper.TaosResultPrecision(res) + var data [][]driver.Value + for { + blockSize, errCode, block := wrapper.TaosFetchRawBlock(res) + if errCode != int(errors.SUCCESS) { + errStr := wrapper.TaosErrorStr(res) + err := errors.NewError(code, errStr) + t.Error(err) + wrapper.TaosFreeResult(res) + return + } + if blockSize == 0 { + break + } + for i := 0; i < blockSize; i++ { + tmp := make([]driver.Value, fileCount) + ReadRow(tmp, block, blockSize, i, rh.ColTypes, precision) + data = append(data, tmp) + } + } + wrapper.TaosFreeResult(res) + assert.Equal(t, 2, len(data)) + row1 := data[0] + assert.Equal(t, now.UnixNano()/1e6, row1[0].(time.Time).UnixNano()/1e6) + assert.Equal(t, true, row1[1].(bool)) + assert.Equal(t, int8(1), row1[2].(int8)) + assert.Equal(t, int16(1), row1[3].(int16)) + assert.Equal(t, int32(1), row1[4].(int32)) + assert.Equal(t, int64(1), row1[5].(int64)) + assert.Equal(t, uint8(1), row1[6].(uint8)) + assert.Equal(t, uint16(1), row1[7].(uint16)) + assert.Equal(t, uint32(1), row1[8].(uint32)) + assert.Equal(t, uint64(1), row1[9].(uint64)) + assert.Equal(t, float32(1), row1[10].(float32)) + assert.Equal(t, float64(1), row1[11].(float64)) + assert.Equal(t, "test_binary", row1[12].(string)) + assert.Equal(t, "test_nchar", row1[13].(string)) + assert.Equal(t, []byte(`{"a":1}`), row1[14].([]byte)) + row2 := data[1] + assert.Equal(t, after1s.UnixNano()/1e6, row2[0].(time.Time).UnixNano()/1e6) + for i := 1; i < 14; i++ { + assert.Nil(t, row2[i]) + } + assert.Equal(t, []byte(`{"a":1}`), row2[14].([]byte)) +} + +func TestReadBlockWithTimeFormat(t *testing.T) { + conn, err := wrapper.TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + + defer wrapper.TaosClose(conn) + res := wrapper.TaosQuery(conn, "drop database if exists test_read_block_tf") + code := wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(res) + defer func() { + res = wrapper.TaosQuery(conn, "drop database if exists test_read_block_tf") + code = wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(res) + }() + res = wrapper.TaosQuery(conn, "create database test_read_block_tf") + code = wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(res) + + res = wrapper.TaosQuery(conn, "create table if not exists test_read_block_tf.all_type (ts timestamp,"+ + "c1 bool,"+ + "c2 tinyint,"+ + "c3 smallint,"+ + "c4 int,"+ + "c5 bigint,"+ + "c6 tinyint unsigned,"+ + "c7 smallint unsigned,"+ + "c8 int unsigned,"+ + "c9 bigint unsigned,"+ + "c10 float,"+ + "c11 double,"+ + "c12 binary(20),"+ + "c13 nchar(20)"+ + ") tags (info json)") + code = wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(res) + now := time.Now() + after1s := now.Add(time.Second) + sql := fmt.Sprintf("insert into test_read_block_tf.t0 using test_read_block_tf.all_type tags('{\"a\":1}') values('%s',false,1,1,1,1,1,1,1,1,1,1,'test_binary','test_nchar')('%s',null,null,null,null,null,null,null,null,null,null,null,null,null)", now.Format(time.RFC3339Nano), after1s.Format(time.RFC3339Nano)) + res = wrapper.TaosQuery(conn, sql) + code = wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(res) + + sql = "select * from test_read_block_tf.all_type" + res = wrapper.TaosQuery(conn, sql) + code = wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + fileCount := wrapper.TaosNumFields(res) + rh, err := wrapper.ReadColumn(res, fileCount) + if err != nil { + t.Error(err) + return + } + precision := wrapper.TaosResultPrecision(res) + var data [][]driver.Value + for { + blockSize, errCode, block := wrapper.TaosFetchRawBlock(res) + if errCode != int(errors.SUCCESS) { + errStr := wrapper.TaosErrorStr(res) + err := errors.NewError(code, errStr) + t.Error(err) + wrapper.TaosFreeResult(res) + return + } + if blockSize == 0 { + break + } + data = ReadBlockWithTimeFormat(block, blockSize, rh.ColTypes, precision, func(ts int64, precision int) driver.Value { + return common.TimestampConvertToTime(ts, precision) + }) + } + wrapper.TaosFreeResult(res) + assert.Equal(t, 2, len(data)) + row1 := data[0] + assert.Equal(t, now.UnixNano()/1e6, row1[0].(time.Time).UnixNano()/1e6) + assert.Equal(t, false, row1[1].(bool)) + assert.Equal(t, int8(1), row1[2].(int8)) + assert.Equal(t, int16(1), row1[3].(int16)) + assert.Equal(t, int32(1), row1[4].(int32)) + assert.Equal(t, int64(1), row1[5].(int64)) + assert.Equal(t, uint8(1), row1[6].(uint8)) + assert.Equal(t, uint16(1), row1[7].(uint16)) + assert.Equal(t, uint32(1), row1[8].(uint32)) + assert.Equal(t, uint64(1), row1[9].(uint64)) + assert.Equal(t, float32(1), row1[10].(float32)) + assert.Equal(t, float64(1), row1[11].(float64)) + assert.Equal(t, "test_binary", row1[12].(string)) + assert.Equal(t, "test_nchar", row1[13].(string)) + assert.Equal(t, []byte(`{"a":1}`), row1[14].([]byte)) + row2 := data[1] + assert.Equal(t, after1s.UnixNano()/1e6, row2[0].(time.Time).UnixNano()/1e6) + for i := 1; i < 14; i++ { + assert.Nil(t, row2[i]) + } + assert.Equal(t, []byte(`{"a":1}`), row2[14].([]byte)) +} + +func TestParseBlock(t *testing.T) { + conn, err := wrapper.TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + + defer wrapper.TaosClose(conn) + res := wrapper.TaosQuery(conn, "drop database if exists parse_block") + code := wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(res) + defer func() { + res = wrapper.TaosQuery(conn, "drop database if exists parse_block") + code = wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(res) + }() + res = wrapper.TaosQuery(conn, "create database parse_block vgroups 1") + code = wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(res) + + res = wrapper.TaosQuery(conn, "create table if not exists parse_block.all_type (ts timestamp,"+ + "c1 bool,"+ + "c2 tinyint,"+ + "c3 smallint,"+ + "c4 int,"+ + "c5 bigint,"+ + "c6 tinyint unsigned,"+ + "c7 smallint unsigned,"+ + "c8 int unsigned,"+ + "c9 bigint unsigned,"+ + "c10 float,"+ + "c11 double,"+ + "c12 binary(20),"+ + "c13 nchar(20)"+ + ") tags (info json)") + code = wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(res) + now := time.Now() + after1s := now.Add(time.Second) + sql := fmt.Sprintf("insert into parse_block.t0 using parse_block.all_type tags('{\"a\":1}') values('%s',1,1,1,1,1,1,1,1,1,1,1,'test_binary','test_nchar')('%s',null,null,null,null,null,null,null,null,null,null,null,null,null)", now.Format(time.RFC3339Nano), after1s.Format(time.RFC3339Nano)) + res = wrapper.TaosQuery(conn, sql) + code = wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + wrapper.TaosFreeResult(res) + + sql = "select * from parse_block.all_type" + res = wrapper.TaosQuery(conn, sql) + code = wrapper.TaosError(res) + if code != 0 { + errStr := wrapper.TaosErrorStr(res) + wrapper.TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + fileCount := wrapper.TaosNumFields(res) + rh, err := wrapper.ReadColumn(res, fileCount) + if err != nil { + t.Error(err) + return + } + precision := wrapper.TaosResultPrecision(res) + var data [][]driver.Value + for { + blockSize, errCode, block := wrapper.TaosFetchRawBlock(res) + if errCode != int(errors.SUCCESS) { + errStr := wrapper.TaosErrorStr(res) + err := errors.NewError(code, errStr) + t.Error(err) + wrapper.TaosFreeResult(res) + return + } + if blockSize == 0 { + break + } + version := RawBlockGetVersion(block) + assert.Equal(t, int32(1), version) + length := RawBlockGetLength(block) + assert.Equal(t, int32(374), length) + rows := RawBlockGetNumOfRows(block) + assert.Equal(t, int32(2), rows) + columns := RawBlockGetNumOfCols(block) + assert.Equal(t, int32(15), columns) + hasColumnSegment := RawBlockGetHasColumnSegment(block) + assert.Equal(t, int32(-2147483648), hasColumnSegment) + groupId := RawBlockGetGroupID(block) + assert.Equal(t, uint64(0), groupId) + infos := make([]RawBlockColInfo, columns) + RawBlockGetColInfo(block, infos) + assert.Equal( + t, + []RawBlockColInfo{ + { + ColType: 9, + Bytes: 8, + }, + { + ColType: 1, + Bytes: 1, + }, + { + ColType: 2, + Bytes: 1, + }, + { + ColType: 3, + Bytes: 2, + }, + { + ColType: 4, + Bytes: 4, + }, + { + ColType: 5, + Bytes: 8, + }, + { + ColType: 11, + Bytes: 1, + }, + { + ColType: 12, + Bytes: 2, + }, + { + ColType: 13, + Bytes: 4, + }, + { + ColType: 14, + Bytes: 8, + }, + { + ColType: 6, + Bytes: 4, + }, + { + ColType: 7, + Bytes: 8, + }, + { + ColType: 8, + Bytes: 22, + }, + { + ColType: 10, + Bytes: 82, + }, + { + ColType: 15, + Bytes: 16384, + }, + }, + infos, + ) + d := ReadBlock(block, blockSize, rh.ColTypes, precision) + data = append(data, d...) + } + wrapper.TaosFreeResult(res) + assert.Equal(t, 2, len(data)) + row1 := data[0] + assert.Equal(t, now.UnixNano()/1e6, row1[0].(time.Time).UnixNano()/1e6) + assert.Equal(t, true, row1[1].(bool)) + assert.Equal(t, int8(1), row1[2].(int8)) + assert.Equal(t, int16(1), row1[3].(int16)) + assert.Equal(t, int32(1), row1[4].(int32)) + assert.Equal(t, int64(1), row1[5].(int64)) + assert.Equal(t, uint8(1), row1[6].(uint8)) + assert.Equal(t, uint16(1), row1[7].(uint16)) + assert.Equal(t, uint32(1), row1[8].(uint32)) + assert.Equal(t, uint64(1), row1[9].(uint64)) + assert.Equal(t, float32(1), row1[10].(float32)) + assert.Equal(t, float64(1), row1[11].(float64)) + assert.Equal(t, "test_binary", row1[12].(string)) + assert.Equal(t, "test_nchar", row1[13].(string)) + assert.Equal(t, []byte(`{"a":1}`), row1[14].([]byte)) + row2 := data[1] + assert.Equal(t, after1s.UnixNano()/1e6, row2[0].(time.Time).UnixNano()/1e6) + for i := 1; i < 14; i++ { + assert.Nil(t, row2[i]) + } + assert.Equal(t, []byte(`{"a":1}`), row2[14].([]byte)) +} diff --git a/common/reqid.go b/common/reqid.go new file mode 100644 index 0000000..f5d711f --- /dev/null +++ b/common/reqid.go @@ -0,0 +1,80 @@ +package common + +import ( + "math/bits" + "os" + "sync/atomic" + "time" + "unsafe" + + "github.com/google/uuid" +) + +var tUUIDHashId int64 +var serialNo int64 +var pid int64 + +func init() { + var tUUID = uuid.New().String() + tUUIDHashId = (int64(murmurHash32([]byte(tUUID), uint32(len(tUUID)))) & 0x07ff) << 52 + pid = (int64(os.Getpid()) & 0x0f) << 48 +} + +func GetReqID() int64 { + ts := (time.Now().UnixNano() / 1e6) >> 8 + val := atomic.AddInt64(&serialNo, 1) + return tUUIDHashId | pid | ((ts & 0x3ffffff) << 20) | (val & 0xfffff) +} + +const ( + c1 uint32 = 0xcc9e2d51 + c2 uint32 = 0x1b873593 +) + +// MurmurHash32 returns the MurmurHash3 sum of data. +func murmurHash32(data []byte, seed uint32) uint32 { + h1 := seed + + nBlocks := len(data) / 4 + p := uintptr(unsafe.Pointer(&data[0])) + p1 := p + uintptr(4*nBlocks) + for ; p < p1; p += 4 { + k1 := *(*uint32)(unsafe.Pointer(p)) + + k1 *= c1 + k1 = bits.RotateLeft32(k1, 15) + k1 *= c2 + + h1 ^= k1 + h1 = bits.RotateLeft32(h1, 13) + h1 = h1*4 + h1 + 0xe6546b64 + } + + tail := data[nBlocks*4:] + + var k1 uint32 + switch len(tail) & 3 { + case 3: + k1 ^= uint32(tail[2]) << 16 + fallthrough + case 2: + k1 ^= uint32(tail[1]) << 8 + fallthrough + case 1: + k1 ^= uint32(tail[0]) + k1 *= c1 + k1 = bits.RotateLeft32(k1, 15) + k1 *= c2 + h1 ^= k1 + } + + h1 ^= uint32(len(data)) + + h1 ^= h1 >> 16 + h1 *= 0x85ebca6b + h1 ^= h1 >> 13 + h1 *= 0xc2b2ae35 + h1 ^= h1 >> 16 + + return h1 +} diff --git a/common/reqid_test.go b/common/reqid_test.go new file mode 100644 index 0000000..58e2fef --- /dev/null +++ b/common/reqid_test.go @@ -0,0 +1,35 @@ +package common + +import ( + "testing" +) + +func BenchmarkGetReqID(b *testing.B) { + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + GetReqID() + } + }) +} + +func BenchmarkGetReqIDParallel(b *testing.B) { + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + GetReqID() + } + }) +} + +func TestGetReqID(t *testing.T) { + t.Log(GetReqID()) +} + +func TestMurmurHash(t *testing.T) { + if murmurHash32([]byte("driver-go"), 0) != 3037880692 { + t.Fatal("fail") + } +} diff --git a/common/serializer/block.go b/common/serializer/block.go new file mode 100644 index 0000000..4a23b53 --- /dev/null +++ b/common/serializer/block.go @@ -0,0 +1,496 @@ +package serializer + +import ( + "bytes" + "errors" + "math" + + "github.com/taosdata/driver-go/v3/common" + "github.com/taosdata/driver-go/v3/common/param" + taosTypes "github.com/taosdata/driver-go/v3/types" +) + +const ( + Int16Size = int(common.Int16Size) + Int32Size = int(common.Int32Size) + Int64Size = int(common.Int64Size) + UInt16Size = int(common.UInt16Size) + UInt32Size = int(common.UInt32Size) + UInt64Size = int(common.UInt64Size) + Float32Size = int(common.Float32Size) + Float64Size = int(common.Float64Size) +) + +func BitmapLen(n int) int { + return ((n) + ((1 << 3) - 1)) >> 3 +} + +func BitPos(n int) int { + return n & ((1 << 3) - 1) +} + +func CharOffset(n int) int { + return n >> 3 +} + +func BMSetNull(c byte, n int) byte { + return c + (1 << (7 - BitPos(n))) +} + +var ColumnNumerNotMatch = errors.New("number of columns does not match") +var DataTypeWrong = errors.New("wrong data type") + +func SerializeRawBlock(params []*param.Param, colType *param.ColumnType) ([]byte, error) { + columns := len(params) + rows := len(params[0].GetValues()) + colTypes, err := colType.GetValue() + if err != nil { + return nil, err + } + if len(colTypes) != columns { + return nil, ColumnNumerNotMatch + } + var block []byte + //version int32 + block = appendUint32(block, uint32(1)) + //length int32 + block = appendUint32(block, uint32(0)) + //rows int32 + block = appendUint32(block, uint32(rows)) + //columns int32 + block = appendUint32(block, uint32(columns)) + //flagSegment int32 + block = appendUint32(block, uint32(0)) + //groupID uint64 + block = appendUint64(block, uint64(0)) + colInfoData := make([]byte, 0, 5*columns) + lengthData := make([]byte, 0, 4*columns) + bitMapLen := BitmapLen(rows) + var data []byte + //colInfo(type+bytes) (int8+int32) * columns + buffer := bytes.NewBuffer(block) + for colIndex := 0; colIndex < columns; colIndex++ { + switch colTypes[colIndex].Type { + case taosTypes.TaosBoolType: + colInfoData = append(colInfoData, common.TSDB_DATA_TYPE_BOOL) + length := common.TypeLengthMap[common.TSDB_DATA_TYPE_BOOL] + colInfoData = appendUint32(colInfoData, uint32(length)) + lengthData = appendUint32(lengthData, uint32(length*rows)) + dataTmp := make([]byte, bitMapLen+rows) + rowData := params[colIndex].GetValues() + for rowIndex := 0; rowIndex < rows; rowIndex++ { + if rowData[rowIndex] == nil { + charOffset := CharOffset(rowIndex) + dataTmp[charOffset] = BMSetNull(dataTmp[charOffset], rowIndex) + } else { + v, is := rowData[rowIndex].(taosTypes.TaosBool) + if !is { + return nil, DataTypeWrong + } + if v { + dataTmp[rowIndex+bitMapLen] = 1 + } + } + } + data = append(data, dataTmp...) + case taosTypes.TaosTinyintType: + colInfoData = append(colInfoData, common.TSDB_DATA_TYPE_TINYINT) + length := common.TypeLengthMap[common.TSDB_DATA_TYPE_TINYINT] + colInfoData = appendUint32(colInfoData, uint32(length)) + lengthData = appendUint32(lengthData, uint32(length*rows)) + dataTmp := make([]byte, bitMapLen+rows) + rowData := params[colIndex].GetValues() + for rowIndex := 0; rowIndex < rows; rowIndex++ { + if rowData[rowIndex] == nil { + charOffset := CharOffset(rowIndex) + dataTmp[charOffset] = BMSetNull(dataTmp[charOffset], rowIndex) + } else { + v, is := rowData[rowIndex].(taosTypes.TaosTinyint) + if !is { + return nil, DataTypeWrong + } + dataTmp[rowIndex+bitMapLen] = byte(v) + } + } + data = append(data, dataTmp...) + case taosTypes.TaosSmallintType: + colInfoData = append(colInfoData, common.TSDB_DATA_TYPE_SMALLINT) + length := common.TypeLengthMap[common.TSDB_DATA_TYPE_SMALLINT] + colInfoData = appendUint32(colInfoData, uint32(length)) + lengthData = appendUint32(lengthData, uint32(length*rows)) + dataTmp := make([]byte, bitMapLen+rows*Int16Size) + rowData := params[colIndex].GetValues() + for rowIndex := 0; rowIndex < rows; rowIndex++ { + if rowData[rowIndex] == nil { + charOffset := CharOffset(rowIndex) + dataTmp[charOffset] = BMSetNull(dataTmp[charOffset], rowIndex) + } else { + v, is := rowData[rowIndex].(taosTypes.TaosSmallint) + if !is { + return nil, DataTypeWrong + } + offset := rowIndex*Int16Size + bitMapLen + dataTmp[offset] = byte(v) + dataTmp[offset+1] = byte(v >> 8) + } + } + data = append(data, dataTmp...) + case taosTypes.TaosIntType: + colInfoData = append(colInfoData, common.TSDB_DATA_TYPE_INT) + length := common.TypeLengthMap[common.TSDB_DATA_TYPE_INT] + colInfoData = appendUint32(colInfoData, uint32(length)) + lengthData = appendUint32(lengthData, uint32(length*rows)) + dataTmp := make([]byte, bitMapLen+rows*Int32Size) + rowData := params[colIndex].GetValues() + for rowIndex := 0; rowIndex < rows; rowIndex++ { + if rowData[rowIndex] == nil { + charOffset := CharOffset(rowIndex) + dataTmp[charOffset] = BMSetNull(dataTmp[charOffset], rowIndex) + } else { + v, is := rowData[rowIndex].(taosTypes.TaosInt) + if !is { + return nil, DataTypeWrong + } + offset := rowIndex*Int32Size + bitMapLen + dataTmp[offset] = byte(v) + dataTmp[offset+1] = byte(v >> 8) + dataTmp[offset+2] = byte(v >> 16) + dataTmp[offset+3] = byte(v >> 24) + } + } + data = append(data, dataTmp...) + case taosTypes.TaosBigintType: + colInfoData = append(colInfoData, common.TSDB_DATA_TYPE_BIGINT) + length := common.TypeLengthMap[common.TSDB_DATA_TYPE_BIGINT] + colInfoData = appendUint32(colInfoData, uint32(length)) + lengthData = appendUint32(lengthData, uint32(length*rows)) + dataTmp := make([]byte, bitMapLen+rows*Int64Size) + rowData := params[colIndex].GetValues() + for rowIndex := 0; rowIndex < rows; rowIndex++ { + if rowData[rowIndex] == nil { + charOffset := CharOffset(rowIndex) + dataTmp[charOffset] = BMSetNull(dataTmp[charOffset], rowIndex) + } else { + v, is := rowData[rowIndex].(taosTypes.TaosBigint) + if !is { + return nil, DataTypeWrong + } + offset := rowIndex*Int64Size + bitMapLen + dataTmp[offset] = byte(v) + dataTmp[offset+1] = byte(v >> 8) + dataTmp[offset+2] = byte(v >> 16) + dataTmp[offset+3] = byte(v >> 24) + dataTmp[offset+4] = byte(v >> 32) + dataTmp[offset+5] = byte(v >> 40) + dataTmp[offset+6] = byte(v >> 48) + dataTmp[offset+7] = byte(v >> 56) + } + } + data = append(data, dataTmp...) + case taosTypes.TaosUTinyintType: + colInfoData = append(colInfoData, common.TSDB_DATA_TYPE_UTINYINT) + length := common.TypeLengthMap[common.TSDB_DATA_TYPE_UTINYINT] + colInfoData = appendUint32(colInfoData, uint32(length)) + lengthData = appendUint32(lengthData, uint32(length*rows)) + dataTmp := make([]byte, bitMapLen+rows) + rowData := params[colIndex].GetValues() + for rowIndex := 0; rowIndex < rows; rowIndex++ { + if rowData[rowIndex] == nil { + charOffset := CharOffset(rowIndex) + dataTmp[charOffset] = BMSetNull(dataTmp[charOffset], rowIndex) + } else { + v, is := rowData[rowIndex].(taosTypes.TaosUTinyint) + if !is { + return nil, DataTypeWrong + } + dataTmp[rowIndex+bitMapLen] = uint8(v) + } + } + data = append(data, dataTmp...) + case taosTypes.TaosUSmallintType: + colInfoData = append(colInfoData, common.TSDB_DATA_TYPE_USMALLINT) + length := common.TypeLengthMap[common.TSDB_DATA_TYPE_USMALLINT] + colInfoData = appendUint32(colInfoData, uint32(length)) + lengthData = appendUint32(lengthData, uint32(length*rows)) + dataTmp := make([]byte, bitMapLen+rows*UInt16Size) + rowData := params[colIndex].GetValues() + for rowIndex := 0; rowIndex < rows; rowIndex++ { + if rowData[rowIndex] == nil { + charOffset := CharOffset(rowIndex) + dataTmp[charOffset] = BMSetNull(dataTmp[charOffset], rowIndex) + } else { + v, is := rowData[rowIndex].(taosTypes.TaosUSmallint) + if !is { + return nil, DataTypeWrong + } + offset := rowIndex*UInt16Size + bitMapLen + dataTmp[offset] = byte(v) + dataTmp[offset+1] = byte(v >> 8) + } + } + data = append(data, dataTmp...) + case taosTypes.TaosUIntType: + colInfoData = append(colInfoData, common.TSDB_DATA_TYPE_UINT) + length := common.TypeLengthMap[common.TSDB_DATA_TYPE_UINT] + colInfoData = appendUint32(colInfoData, uint32(length)) + lengthData = appendUint32(lengthData, uint32(length*rows)) + dataTmp := make([]byte, bitMapLen+rows*UInt32Size) + rowData := params[colIndex].GetValues() + for rowIndex := 0; rowIndex < rows; rowIndex++ { + if rowData[rowIndex] == nil { + charOffset := CharOffset(rowIndex) + dataTmp[charOffset] = BMSetNull(dataTmp[charOffset], rowIndex) + } else { + v, is := rowData[rowIndex].(taosTypes.TaosUInt) + if !is { + return nil, DataTypeWrong + } + offset := rowIndex*UInt32Size + bitMapLen + dataTmp[offset] = byte(v) + dataTmp[offset+1] = byte(v >> 8) + dataTmp[offset+2] = byte(v >> 16) + dataTmp[offset+3] = byte(v >> 24) + } + } + data = append(data, dataTmp...) + + case taosTypes.TaosUBigintType: + colInfoData = append(colInfoData, common.TSDB_DATA_TYPE_UBIGINT) + length := common.TypeLengthMap[common.TSDB_DATA_TYPE_UBIGINT] + colInfoData = appendUint32(colInfoData, uint32(length)) + lengthData = appendUint32(lengthData, uint32(length*rows)) + dataTmp := make([]byte, bitMapLen+rows*UInt64Size) + rowData := params[colIndex].GetValues() + for rowIndex := 0; rowIndex < rows; rowIndex++ { + if rowData[rowIndex] == nil { + charOffset := CharOffset(rowIndex) + dataTmp[charOffset] = BMSetNull(dataTmp[charOffset], rowIndex) + } else { + v, is := rowData[rowIndex].(taosTypes.TaosUBigint) + if !is { + return nil, DataTypeWrong + } + offset := rowIndex*UInt64Size + bitMapLen + dataTmp[offset] = byte(v) + dataTmp[offset+1] = byte(v >> 8) + dataTmp[offset+2] = byte(v >> 16) + dataTmp[offset+3] = byte(v >> 24) + dataTmp[offset+4] = byte(v >> 32) + dataTmp[offset+5] = byte(v >> 40) + dataTmp[offset+6] = byte(v >> 48) + dataTmp[offset+7] = byte(v >> 56) + } + } + data = append(data, dataTmp...) + + case taosTypes.TaosFloatType: + colInfoData = append(colInfoData, common.TSDB_DATA_TYPE_FLOAT) + length := common.TypeLengthMap[common.TSDB_DATA_TYPE_FLOAT] + colInfoData = appendUint32(colInfoData, uint32(length)) + lengthData = appendUint32(lengthData, uint32(length*rows)) + dataTmp := make([]byte, bitMapLen+rows*Float32Size) + rowData := params[colIndex].GetValues() + for rowIndex := 0; rowIndex < rows; rowIndex++ { + if rowData[rowIndex] == nil { + charOffset := CharOffset(rowIndex) + dataTmp[charOffset] = BMSetNull(dataTmp[charOffset], rowIndex) + } else { + v, is := rowData[rowIndex].(taosTypes.TaosFloat) + if !is { + return nil, DataTypeWrong + } + offset := rowIndex*Float32Size + bitMapLen + vv := math.Float32bits(float32(v)) + dataTmp[offset] = byte(vv) + dataTmp[offset+1] = byte(vv >> 8) + dataTmp[offset+2] = byte(vv >> 16) + dataTmp[offset+3] = byte(vv >> 24) + } + } + data = append(data, dataTmp...) + + case taosTypes.TaosDoubleType: + colInfoData = append(colInfoData, common.TSDB_DATA_TYPE_DOUBLE) + length := common.TypeLengthMap[common.TSDB_DATA_TYPE_DOUBLE] + colInfoData = appendUint32(colInfoData, uint32(length)) + lengthData = appendUint32(lengthData, uint32(length*rows)) + dataTmp := make([]byte, bitMapLen+rows*Float64Size) + rowData := params[colIndex].GetValues() + for rowIndex := 0; rowIndex < rows; rowIndex++ { + if rowData[rowIndex] == nil { + charOffset := CharOffset(rowIndex) + dataTmp[charOffset] = BMSetNull(dataTmp[charOffset], rowIndex) + } else { + v, is := rowData[rowIndex].(taosTypes.TaosDouble) + if !is { + return nil, DataTypeWrong + } + offset := rowIndex*Float64Size + bitMapLen + vv := math.Float64bits(float64(v)) + dataTmp[offset] = byte(vv) + dataTmp[offset+1] = byte(vv >> 8) + dataTmp[offset+2] = byte(vv >> 16) + dataTmp[offset+3] = byte(vv >> 24) + dataTmp[offset+4] = byte(vv >> 32) + dataTmp[offset+5] = byte(vv >> 40) + dataTmp[offset+6] = byte(vv >> 48) + dataTmp[offset+7] = byte(vv >> 56) + } + } + data = append(data, dataTmp...) + case taosTypes.TaosBinaryType: + colInfoData = append(colInfoData, common.TSDB_DATA_TYPE_BINARY) + colInfoData = appendUint32(colInfoData, uint32(0)) + length := 0 + dataTmp := make([]byte, Int32Size*rows) + rowData := params[colIndex].GetValues() + for rowIndex := 0; rowIndex < rows; rowIndex++ { + offset := Int32Size * rowIndex + if rowData[rowIndex] == nil { + for i := 0; i < Int32Size; i++ { + // -1 + dataTmp[offset+i] = byte(255) + } + } else { + v, is := rowData[rowIndex].(taosTypes.TaosBinary) + if !is { + return nil, DataTypeWrong + } + for i := 0; i < Int32Size; i++ { + dataTmp[offset+i] = byte(length >> (8 * i)) + } + dataTmp = appendUint16(dataTmp, uint16(len(v))) + dataTmp = append(dataTmp, v...) + length += len(v) + Int16Size + } + } + lengthData = appendUint32(lengthData, uint32(length)) + data = append(data, dataTmp...) + case taosTypes.TaosNcharType: + colInfoData = append(colInfoData, common.TSDB_DATA_TYPE_NCHAR) + colInfoData = appendUint32(colInfoData, uint32(0)) + length := 0 + dataTmp := make([]byte, Int32Size*rows) + rowData := params[colIndex].GetValues() + for rowIndex := 0; rowIndex < rows; rowIndex++ { + offset := Int32Size * rowIndex + if rowData[rowIndex] == nil { + for i := 0; i < Int32Size; i++ { + // -1 + dataTmp[offset+i] = byte(255) + } + } else { + v, is := rowData[rowIndex].(taosTypes.TaosNchar) + if !is { + return nil, DataTypeWrong + } + for i := 0; i < Int32Size; i++ { + dataTmp[offset+i] = byte(length >> (8 * i)) + } + rs := []rune(v) + dataTmp = appendUint16(dataTmp, uint16(len(rs)*4)) + for _, r := range rs { + dataTmp = appendUint32(dataTmp, uint32(r)) + } + length += len(rs)*4 + Int16Size + } + } + lengthData = appendUint32(lengthData, uint32(length)) + data = append(data, dataTmp...) + case taosTypes.TaosTimestampType: + colInfoData = append(colInfoData, common.TSDB_DATA_TYPE_TIMESTAMP) + length := common.TypeLengthMap[common.TSDB_DATA_TYPE_TIMESTAMP] + colInfoData = appendUint32(colInfoData, uint32(length)) + lengthData = appendUint32(lengthData, uint32(length*rows)) + dataTmp := make([]byte, bitMapLen+rows*Int64Size) + rowData := params[colIndex].GetValues() + for rowIndex := 0; rowIndex < rows; rowIndex++ { + if rowData[rowIndex] == nil { + charOffset := CharOffset(rowIndex) + dataTmp[charOffset] = BMSetNull(dataTmp[charOffset], rowIndex) + } else { + v, is := rowData[rowIndex].(taosTypes.TaosTimestamp) + if !is { + return nil, DataTypeWrong + } + vv := common.TimeToTimestamp(v.T, v.Precision) + offset := rowIndex*Int64Size + bitMapLen + dataTmp[offset] = byte(vv) + dataTmp[offset+1] = byte(vv >> 8) + dataTmp[offset+2] = byte(vv >> 16) + dataTmp[offset+3] = byte(vv >> 24) + dataTmp[offset+4] = byte(vv >> 32) + dataTmp[offset+5] = byte(vv >> 40) + dataTmp[offset+6] = byte(vv >> 48) + dataTmp[offset+7] = byte(vv >> 56) + } + } + data = append(data, dataTmp...) + case taosTypes.TaosJsonType: + colInfoData = append(colInfoData, common.TSDB_DATA_TYPE_JSON) + colInfoData = appendUint32(colInfoData, uint32(0)) + length := 0 + dataTmp := make([]byte, Int32Size*rows) + rowData := params[colIndex].GetValues() + for rowIndex := 0; rowIndex < rows; rowIndex++ { + offset := Int32Size * rowIndex + if rowData[rowIndex] == nil { + for i := 0; i < Int32Size; i++ { + // -1 + dataTmp[offset+i] = byte(255) + } + } else { + v, is := rowData[rowIndex].(taosTypes.TaosJson) + if !is { + return nil, DataTypeWrong + } + for i := 0; i < Int32Size; i++ { + dataTmp[offset+i] = byte(length >> (8 * i)) + } + dataTmp = appendUint16(dataTmp, uint16(len(v))) + dataTmp = append(dataTmp, v...) + length += len(v) + Int16Size + } + } + lengthData = appendUint32(lengthData, uint32(length)) + data = append(data, dataTmp...) + } + } + buffer.Write(colInfoData) + buffer.Write(lengthData) + buffer.Write(data) + block = buffer.Bytes() + for i := 0; i < Int32Size; i++ { + block[4+i] = byte(len(block) >> (8 * i)) + } + return block, nil +} + +func appendUint16(b []byte, v uint16) []byte { + return append(b, + byte(v), + byte(v>>8), + ) +} + +func appendUint32(b []byte, v uint32) []byte { + return append(b, + byte(v), + byte(v>>8), + byte(v>>16), + byte(v>>24), + ) +} + +func appendUint64(b []byte, v uint64) []byte { + return append(b, + byte(v), + byte(v>>8), + byte(v>>16), + byte(v>>24), + byte(v>>32), + byte(v>>40), + byte(v>>48), + byte(v>>56), + ) +} diff --git a/common/serializer/block_test.go b/common/serializer/block_test.go new file mode 100644 index 0000000..6023592 --- /dev/null +++ b/common/serializer/block_test.go @@ -0,0 +1,324 @@ +package serializer + +import ( + "math" + "reflect" + "testing" + "time" + + "github.com/taosdata/driver-go/v3/common/param" +) + +func TestSerializeRawBlock(t *testing.T) { + type args struct { + params []*param.Param + colType *param.ColumnType + } + tests := []struct { + name string + args args + want []byte + wantErr bool + }{ + { + name: "all type", + args: args{ + params: []*param.Param{ + param.NewParam(1).AddTimestamp(time.Unix(0, 0), 0), + param.NewParam(1).AddBool(true), + param.NewParam(1).AddTinyint(127), + param.NewParam(1).AddSmallint(32767), + param.NewParam(1).AddInt(2147483647), + param.NewParam(1).AddBigint(9223372036854775807), + param.NewParam(1).AddUTinyint(255), + param.NewParam(1).AddUSmallint(65535), + param.NewParam(1).AddUInt(4294967295), + param.NewParam(1).AddUBigint(18446744073709551615), + param.NewParam(1).AddFloat(math.MaxFloat32), + param.NewParam(1).AddDouble(math.MaxFloat64), + param.NewParam(1).AddBinary([]byte("ABC")), + param.NewParam(1).AddNchar("涛思数据"), + }, + colType: param.NewColumnType(14). + AddTimestamp(). + AddBool(). + AddTinyint(). + AddSmallint(). + AddInt(). + AddBigint(). + AddUTinyint(). + AddUSmallint(). + AddUInt(). + AddUBigint(). + AddFloat(). + AddDouble(). + AddBinary(0). + AddNchar(0), + }, + want: []byte{ + 0x01, 0x00, 0x00, 0x00, //version + 0xf8, 0x00, 0x00, 0x00, //length + 0x01, 0x00, 0x00, 0x00, //rows + 0x0e, 0x00, 0x00, 0x00, //columns + 0x00, 0x00, 0x00, 0x00, //flagSegment + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, //groupID + //types + 0x09, 0x08, 0x00, 0x00, 0x00, //1 + 0x01, 0x01, 0x00, 0x00, 0x00, //2 + 0x02, 0x01, 0x00, 0x00, 0x00, //3 + 0x03, 0x02, 0x00, 0x00, 0x00, //4 + 0x04, 0x04, 0x00, 0x00, 0x00, //5 + 0x05, 0x08, 0x00, 0x00, 0x00, //6 + 0x0b, 0x01, 0x00, 0x00, 0x00, //7 + 0x0c, 0x02, 0x00, 0x00, 0x00, //8 + 0x0d, 0x04, 0x00, 0x00, 0x00, //9 + 0x0e, 0x08, 0x00, 0x00, 0x00, //10 + 0x06, 0x04, 0x00, 0x00, 0x00, //11 + 0x07, 0x08, 0x00, 0x00, 0x00, //12 + 0x08, 0x00, 0x00, 0x00, 0x00, //13 + 0x0a, 0x00, 0x00, 0x00, 0x00, //14 + //lengths + 0x08, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x00, 0x00, + 0x12, 0x00, 0x00, 0x00, + 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, //ts + 0x00, + 0x01, //bool + 0x00, + 0x7f, //i8 + 0x00, + 0xff, 0x7f, //i16 + 0x00, + 0xff, 0xff, 0xff, 0x7f, //i32 + 0x00, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f, //i64 + 0x00, + 0xff, //u8 + 0x00, + 0xff, 0xff, //u16 + 0x00, + 0xff, 0xff, 0xff, 0xff, //u32 + 0x00, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, //u64 + 0x00, + 0xff, 0xff, 0x7f, 0x7f, //f32 + 0x00, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xef, 0x7f, //f64 + 0x00, 0x00, 0x00, 0x00, + 0x03, 0x00, //binary + 0x41, 0x42, 0x43, + 0x00, 0x00, 0x00, 0x00, + 0x10, 0x00, //nchar + 0x9b, 0x6d, 0x00, 0x00, 0x1d, 0x60, 0x00, 0x00, 0x70, 0x65, 0x00, 0x00, 0x6e, 0x63, 0x00, 0x00, + }, + wantErr: false, + }, + { + name: "all with nil", + args: args{ + params: []*param.Param{ + param.NewParam(3).AddTimestamp(time.Unix(1666248065, 0), 0).AddNull().AddTimestamp(time.Unix(1666248067, 0), 0), + param.NewParam(3).AddBool(true).AddNull().AddBool(true), + param.NewParam(3).AddTinyint(1).AddNull().AddTinyint(1), + param.NewParam(3).AddSmallint(1).AddNull().AddSmallint(1), + param.NewParam(3).AddInt(1).AddNull().AddInt(1), + param.NewParam(3).AddBigint(1).AddNull().AddBigint(1), + param.NewParam(3).AddUTinyint(1).AddNull().AddUTinyint(1), + param.NewParam(3).AddUSmallint(1).AddNull().AddUSmallint(1), + param.NewParam(3).AddUInt(1).AddNull().AddUInt(1), + param.NewParam(3).AddUBigint(1).AddNull().AddUBigint(1), + param.NewParam(3).AddFloat(1).AddNull().AddFloat(1), + param.NewParam(3).AddDouble(1).AddNull().AddDouble(1), + param.NewParam(3).AddBinary([]byte("test_binary")).AddNull().AddBinary([]byte("test_binary")), + param.NewParam(3).AddNchar("test_nchar").AddNull().AddNchar("test_nchar"), + param.NewParam(3).AddJson([]byte("{\"a\":1}")).AddNull().AddJson([]byte("{\"a\":1}")), + }, + colType: param.NewColumnType(15). + AddTimestamp(). + AddBool(). + AddTinyint(). + AddSmallint(). + AddInt(). + AddBigint(). + AddUTinyint(). + AddUSmallint(). + AddUInt(). + AddUBigint(). + AddFloat(). + AddDouble(). + AddBinary(0). + AddNchar(0). + AddJson(0), + }, + want: []byte{ + 0x01, 0x00, 0x00, 0x00, + 0xec, 0x01, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, + 0x0f, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + //types + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x01, 0x01, 0x00, 0x00, 0x00, + 0x02, 0x01, 0x00, 0x00, 0x00, + 0x03, 0x02, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + 0x05, 0x08, 0x00, 0x00, 0x00, + 0x0b, 0x01, 0x00, 0x00, 0x00, + 0x0c, 0x02, 0x00, 0x00, 0x00, + 0x0d, 0x04, 0x00, 0x00, 0x00, + 0x0e, 0x08, 0x00, 0x00, 0x00, + 0x06, 0x04, 0x00, 0x00, 0x00, + 0x07, 0x08, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, 0x00, + 0x0a, 0x00, 0x00, 0x00, 0x00, + 0x0f, 0x00, 0x00, 0x00, 0x00, + //lengths + 0x18, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, + 0x18, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, + 0x18, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, + 0x18, 0x00, 0x00, 0x00, + 0x1a, 0x00, 0x00, 0x00, + 0x54, 0x00, 0x00, 0x00, + 0x12, 0x00, 0x00, 0x00, + // ts + 0x40, + 0xe8, 0xbf, 0x1f, 0xf4, 0x83, 0x01, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xb8, 0xc7, 0x1f, 0xf4, 0x83, 0x01, 0x00, 0x00, + + // bool + 0x40, + 0x01, + 0x00, + 0x01, + + // i8 + 0x40, + 0x01, + 0x00, + 0x01, + + //int16 + 0x40, + 0x01, 0x00, + 0x00, 0x00, + 0x01, 0x00, + + //int32 + 0x40, + 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + + //int64 + 0x40, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + //uint8 + 0x40, + 0x01, + 0x00, + 0x01, + + //uint16 + 0x40, + 0x01, 0x00, + 0x00, 0x00, + 0x01, 0x00, + + //uint32 + 0x40, + 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + + //uint64 + 0x40, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + //float + 0x40, + 0x00, 0x00, 0x80, 0x3f, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x80, 0x3f, + + //double + 0x40, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xf0, 0x3f, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xf0, 0x3f, + + //binary + 0x00, 0x00, 0x00, 0x00, + 0xff, 0xff, 0xff, 0xff, + 0x0d, 0x00, 0x00, 0x00, + 0x0b, 0x00, + 0x74, 0x65, 0x73, 0x74, 0x5f, 0x62, 0x69, 0x6e, 0x61, 0x72, 0x79, + 0x0b, 0x00, + 0x74, 0x65, 0x73, 0x74, 0x5f, 0x62, 0x69, 0x6e, 0x61, 0x72, 0x79, + + //nchar + 0x00, 0x00, 0x00, 0x00, + 0xff, 0xff, 0xff, 0xff, + 0x2a, 0x00, 0x00, 0x00, + 0x28, 0x00, + 0x74, 0x00, 0x00, 0x00, 0x65, 0x00, 0x00, 0x00, 0x73, 0x00, + 0x00, 0x00, 0x74, 0x00, 0x00, 0x00, 0x5f, 0x00, 0x00, 0x00, + 0x6e, 0x00, 0x00, 0x00, 0x63, 0x00, 0x00, 0x00, 0x68, 0x00, + 0x00, 0x00, 0x61, 0x00, 0x00, 0x00, 0x72, 0x00, 0x00, 0x00, + 0x28, 0x00, + 0x74, 0x00, 0x00, 0x00, 0x65, 0x00, 0x00, 0x00, 0x73, 0x00, + 0x00, 0x00, 0x74, 0x00, 0x00, 0x00, 0x5f, 0x00, 0x00, 0x00, + 0x6e, 0x00, 0x00, 0x00, 0x63, 0x00, 0x00, 0x00, 0x68, 0x00, + 0x00, 0x00, 0x61, 0x00, 0x00, 0x00, 0x72, 0x00, 0x00, 0x00, + + //json + 0x00, 0x00, 0x00, 0x00, + 0xff, 0xff, 0xff, 0xff, + 0x09, 0x00, 0x00, 0x00, + 0x07, 0x00, + 0x7b, 0x22, 0x61, 0x22, 0x3a, 0x31, 0x7d, + 0x07, 0x00, + 0x7b, 0x22, 0x61, 0x22, 0x3a, 0x31, 0x7d, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := SerializeRawBlock(tt.args.params, tt.args.colType) + if (err != nil) != tt.wantErr { + t.Errorf("SerializeRawBlock() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("SerializeRawBlock() got = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/common/sql.go b/common/sql.go index 39fa08d..abd60d2 100644 --- a/common/sql.go +++ b/common/sql.go @@ -10,7 +10,7 @@ import ( "time" ) -func InterpolateParams(query string, args []driver.Value) (string, error) { +func InterpolateParams(query string, args []driver.NamedValue) (string, error) { // Number of ? should be same to len(args) if strings.Count(query, "?") != len(args) { return "", driver.ErrSkip @@ -27,7 +27,7 @@ func InterpolateParams(query string, args []driver.Value) (string, error) { buf.WriteString(query[i : i+q]) i += q - arg := args[argPos] + arg := args[argPos].Value argPos++ if arg == nil { @@ -98,3 +98,14 @@ func InterpolateParams(query string, args []driver.Value) (string, error) { } return buf.String(), nil } + +func ValueArgsToNamedValueArgs(args []driver.Value) (values []driver.NamedValue) { + values = make([]driver.NamedValue, len(args)) + for i, arg := range args { + values[i] = driver.NamedValue{ + Ordinal: i + 1, + Value: arg, + } + } + return +} diff --git a/common/sql_test.go b/common/sql_test.go index 7957ddc..543294d 100644 --- a/common/sql_test.go +++ b/common/sql_test.go @@ -12,7 +12,7 @@ import ( func TestInterpolateParams(t *testing.T) { type args struct { query string - args []driver.Value + args []driver.NamedValue } tests := []struct { name string @@ -41,24 +41,24 @@ func TestInterpolateParams(t *testing.T) { "bs = ? and " + "str = ? and " + "nil is ?", - args: []driver.Value{ - time.Unix(1643068800, 0).UTC(), - int8(1), - int16(2), - int32(3), - int64(4), - uint8(1), - uint16(2), - uint32(3), - uint64(4), - float32(5.2), - float64(5.2), - int(6), - uint(6), - bool(true), - []byte("'bytes'"), - []byte("'str'"), - nil, + args: []driver.NamedValue{ + {Ordinal: 1, Value: time.Unix(1643068800, 0).UTC()}, + {Ordinal: 2, Value: int8(1)}, + {Ordinal: 3, Value: int16(2)}, + {Ordinal: 4, Value: int32(3)}, + {Ordinal: 5, Value: int64(4)}, + {Ordinal: 6, Value: uint8(1)}, + {Ordinal: 7, Value: uint16(2)}, + {Ordinal: 8, Value: uint32(3)}, + {Ordinal: 9, Value: uint64(4)}, + {Ordinal: 10, Value: float32(5.2)}, + {Ordinal: 11, Value: float64(5.2)}, + {Ordinal: 12, Value: 6}, + {Ordinal: 13, Value: uint(6)}, + {Ordinal: 14, Value: true}, + {Ordinal: 15, Value: []byte("'bytes'")}, + {Ordinal: 16, Value: []byte("'str'")}, + {Ordinal: 17, Value: nil}, }, }, want: "select * from t1 where " + diff --git a/common/tmq.go b/common/tmq.go deleted file mode 100644 index 4d5389b..0000000 --- a/common/tmq.go +++ /dev/null @@ -1,27 +0,0 @@ -package common - -type Meta struct { - Type string `json:"type"` - TableName string `json:"tableName"` - TableType string `json:"tableType"` - Columns []struct { - Name string `json:"name"` - Type int `json:"type"` - Length int `json:"length"` - } `json:"columns"` - Using string `json:"using"` - TagNum int `json:"tagNum"` - Tags []struct { - Name string `json:"name"` - Type int `json:"type"` - Value interface{} `json:"value"` - } `json:"tags"` - TableNameList []string `json:"tableNameList"` - AlterType int `json:"alterType"` - ColName string `json:"colName"` - ColNewName string `json:"colNewName"` - ColType int `json:"colType"` - ColLength int `json:"colLength"` - ColValue string `json:"colValue"` - ColValueNull bool `json:"colValueNull"` -} diff --git a/common/tmq/config.go b/common/tmq/config.go new file mode 100644 index 0000000..b17084c --- /dev/null +++ b/common/tmq/config.go @@ -0,0 +1,34 @@ +package tmq + +import ( + "fmt" + "reflect" +) + +type ConfigValue interface{} +type ConfigMap map[string]ConfigValue + +func (m ConfigMap) Get(key string, defval ConfigValue) (ConfigValue, error) { + return m.get(key, defval) +} + +func (m ConfigMap) get(key string, defval ConfigValue) (ConfigValue, error) { + v, ok := m[key] + if !ok { + return defval, nil + } + + if defval != nil && reflect.TypeOf(defval) != reflect.TypeOf(v) { + return nil, fmt.Errorf("%s expects type %T, not %T", key, defval, v) + } + + return v, nil +} + +func (m ConfigMap) Clone() ConfigMap { + m2 := make(ConfigMap) + for k, v := range m { + m2[k] = v + } + return m2 +} diff --git a/common/tmq/event.go b/common/tmq/event.go new file mode 100644 index 0000000..e6a9e80 --- /dev/null +++ b/common/tmq/event.go @@ -0,0 +1,184 @@ +package tmq + +import ( + "database/sql/driver" + "encoding/json" + "fmt" + + taosError "github.com/taosdata/driver-go/v3/errors" +) + +type Data struct { + TableName string + Data [][]driver.Value +} +type Event interface { + String() string +} + +type Error struct { + code int + str string +} + +const ErrorOther = 0xffff + +func NewTMQError(code int, str string) Error { + return Error{ + code: code, + str: str, + } +} + +func NewTMQErrorWithErr(err error) Error { + tErr, ok := err.(*taosError.TaosError) + if ok { + return Error{ + code: int(tErr.Code), + str: tErr.ErrStr, + } + } else { + return Error{ + code: ErrorOther, + str: err.Error(), + } + } +} + +func (e Error) String() string { + return fmt.Sprintf("[0x%x] %s", e.code, e.str) +} + +func (e Error) Error() string { + return e.String() +} + +func (e Error) Code() int { + return e.code +} + +type Message interface { + Topic() string + DBName() string + Value() interface{} +} + +type DataMessage struct { + dbName string + topic string + data []*Data +} + +func (m *DataMessage) String() string { + data, _ := json.Marshal(m.data) + return fmt.Sprintf("DataMessage: %s[%s]:%s", m.topic, m.dbName, string(data)) +} + +func (m *DataMessage) SetDbName(dbName string) { + m.dbName = dbName +} + +func (m *DataMessage) SetTopic(topic string) { + m.topic = topic +} + +func (m *DataMessage) SetData(data []*Data) { + m.data = data +} + +func (m *DataMessage) Topic() string { + return m.topic +} + +func (m *DataMessage) DBName() string { + return m.dbName +} + +func (m *DataMessage) Value() interface{} { + return m.data +} + +type MetaMessage struct { + dbName string + topic string + offset string + meta *Meta +} + +func (m *MetaMessage) String() string { + data, _ := json.Marshal(m.meta) + return fmt.Sprintf("MetaMessage: %s[%s]:%s", m.topic, m.dbName, string(data)) +} + +func (m *MetaMessage) SetDbName(dbName string) { + m.dbName = dbName +} + +func (m *MetaMessage) SetTopic(topic string) { + m.topic = topic +} + +func (m *MetaMessage) SetOffset(offset string) { + m.offset = offset +} + +func (m *MetaMessage) SetMeta(meta *Meta) { + m.meta = meta +} + +func (m *MetaMessage) Topic() string { + return m.topic +} + +func (m *MetaMessage) DBName() string { + return m.dbName +} + +func (m *MetaMessage) Value() interface{} { + return m.meta +} + +type MetaDataMessage struct { + dbName string + topic string + offset string + metaData *MetaData +} + +func (m *MetaDataMessage) String() string { + data, _ := json.Marshal(m.metaData) + return fmt.Sprintf("MetaDataMessage: %s[%s]:%s", m.topic, m.dbName, string(data)) +} + +func (m *MetaDataMessage) SetDbName(dbName string) { + m.dbName = dbName +} + +func (m *MetaDataMessage) SetTopic(topic string) { + m.topic = topic +} + +func (m *MetaDataMessage) SetOffset(offset string) { + m.offset = offset +} + +func (m *MetaDataMessage) SetMetaData(metaData *MetaData) { + m.metaData = metaData +} + +type MetaData struct { + Meta *Meta + Data []*Data +} + +func (m *MetaDataMessage) Topic() string { + return m.topic +} + +func (m *MetaDataMessage) DBName() string { + return m.dbName +} + +func (m *MetaDataMessage) Value() interface{} { + return m.metaData +} diff --git a/common/tmq/tmq.go b/common/tmq/tmq.go new file mode 100644 index 0000000..a43607b --- /dev/null +++ b/common/tmq/tmq.go @@ -0,0 +1,42 @@ +package tmq + +type Meta struct { + Type string `json:"type"` + TableName string `json:"tableName"` + TableType string `json:"tableType"` + CreateList []*CreateItem `json:"createList"` + Columns []*Column `json:"columns"` + Using string `json:"using"` + TagNum int `json:"tagNum"` + Tags []*Tag `json:"tags"` + TableNameList []string `json:"tableNameList"` + AlterType int `json:"alterType"` + ColName string `json:"colName"` + ColNewName string `json:"colNewName"` + ColType int `json:"colType"` + ColLength int `json:"colLength"` + ColValue string `json:"colValue"` + ColValueNull bool `json:"colValueNull"` +} + +type Tag struct { + Name string `json:"name"` + Type int `json:"type"` + Value interface{} `json:"value"` +} + +type Column struct { + Name string `json:"name"` + Type int `json:"type"` + Length int `json:"length"` +} + +type CreateItem struct { + TableName string `json:"tableName"` + Using string `json:"using"` + TagNum int `json:"tagNum"` + Tags []*Tag `json:"tags"` +} + +type TopicPartition struct { +} diff --git a/common/tmq/tmq_test.go b/common/tmq/tmq_test.go new file mode 100644 index 0000000..c15999b --- /dev/null +++ b/common/tmq/tmq_test.go @@ -0,0 +1,62 @@ +package tmq + +import ( + "encoding/json" + "testing" +) + +const createJson = `{ + "type": "create", + "tableName": "t1", + "tableType": "super", + "columns": [ + { + "name": "c1", + "type": 0, + "length": 0 + }, + { + "name": "c2", + "type": 8, + "length": 8 + } + ], + "tags": [ + { + "name": "t1", + "type": 0, + "length": 0 + }, + { + "name": "t2", + "type": 8, + "length": 8 + } + ] +}` +const dropJson = `{ + "type":"drop", + "tableName":"t1", + "tableType":"super", + "tableNameList":["t1", "t2"] +}` + +func TestCreateJson(t *testing.T) { + var obj Meta + err := json.Unmarshal([]byte(createJson), &obj) + if err != nil { + t.Log(err) + return + } + t.Log(obj) +} + +func TestDropJson(t *testing.T) { + var obj Meta + err := json.Unmarshal([]byte(dropJson), &obj) + if err != nil { + t.Log(err) + return + } + t.Log(obj) +} diff --git a/common/ws.go b/common/ws.go new file mode 100644 index 0000000..c3a215f --- /dev/null +++ b/common/ws.go @@ -0,0 +1,25 @@ +package common + +import ( + "net/http" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +const ( + BufferSize4M = 4 * 1024 * 1024 + DefaultMessageTimeout = time.Minute * 5 + DefaultPongWait = 60 * time.Second + DefaultPingPeriod = (60 * time.Second * 9) / 10 + DefaultWriteWait = 10 * time.Second +) + +var DefaultDialer = websocket.Dialer{ + Proxy: http.ProxyFromEnvironment, + HandshakeTimeout: 45 * time.Second, + ReadBufferSize: BufferSize4M, + WriteBufferSize: BufferSize4M, + WriteBufferPool: &sync.Pool{}, +} diff --git a/examples/stmtoverws/main.go b/examples/stmtoverws/main.go new file mode 100644 index 0000000..a9c0d5a --- /dev/null +++ b/examples/stmtoverws/main.go @@ -0,0 +1,267 @@ +package main + +import ( + "database/sql" + "fmt" + "time" + + "github.com/taosdata/driver-go/v3/common" + "github.com/taosdata/driver-go/v3/common/param" + _ "github.com/taosdata/driver-go/v3/taosRestful" + "github.com/taosdata/driver-go/v3/ws/stmt" +) + +func main() { + db, err := sql.Open("taosRestful", "root:taosdata@http(localhost:6041)/") + if err != nil { + panic(err) + } + defer db.Close() + prepareEnv(db) + + config := stmt.NewConfig("ws://127.0.0.1:6041/rest/stmt", 0) + config.SetConnectUser("root") + config.SetConnectPass("taosdata") + config.SetConnectDB("example_ws_stmt") + config.SetMessageTimeout(common.DefaultMessageTimeout) + config.SetWriteWait(common.DefaultWriteWait) + config.SetErrorHandler(func(connector *stmt.Connector, err error) { + panic(err) + }) + config.SetCloseHandler(func() { + fmt.Println("stmt connector closed") + }) + + connector, err := stmt.NewConnector(config) + if err != nil { + panic(err) + } + now := time.Now() + { + stmt, err := connector.Init() + if err != nil { + panic(err) + } + err = stmt.Prepare("insert into ? using all_json tags(?) values(?,?,?,?,?,?,?,?,?,?,?,?,?,?)") + if err != nil { + panic(err) + } + err = stmt.SetTableName("tb1") + if err != nil { + panic(err) + } + err = stmt.SetTags(param.NewParam(1).AddJson([]byte(`{"tb":1}`)), param.NewColumnType(1).AddJson(0)) + if err != nil { + panic(err) + } + params := []*param.Param{ + param.NewParam(3).AddTimestamp(now, 0).AddTimestamp(now.Add(time.Second), 0).AddTimestamp(now.Add(time.Second*2), 0), + param.NewParam(3).AddBool(true).AddNull().AddBool(true), + param.NewParam(3).AddTinyint(1).AddNull().AddTinyint(1), + param.NewParam(3).AddSmallint(1).AddNull().AddSmallint(1), + param.NewParam(3).AddInt(1).AddNull().AddInt(1), + param.NewParam(3).AddBigint(1).AddNull().AddBigint(1), + param.NewParam(3).AddUTinyint(1).AddNull().AddUTinyint(1), + param.NewParam(3).AddUSmallint(1).AddNull().AddUSmallint(1), + param.NewParam(3).AddUInt(1).AddNull().AddUInt(1), + param.NewParam(3).AddUBigint(1).AddNull().AddUBigint(1), + param.NewParam(3).AddFloat(1).AddNull().AddFloat(1), + param.NewParam(3).AddDouble(1).AddNull().AddDouble(1), + param.NewParam(3).AddBinary([]byte("test_binary")).AddNull().AddBinary([]byte("test_binary")), + param.NewParam(3).AddNchar("test_nchar").AddNull().AddNchar("test_nchar"), + } + paramTypes := param.NewColumnType(14). + AddTimestamp(). + AddBool(). + AddTinyint(). + AddSmallint(). + AddInt(). + AddBigint(). + AddUTinyint(). + AddUSmallint(). + AddUInt(). + AddUBigint(). + AddFloat(). + AddDouble(). + AddBinary(0). + AddNchar(0) + err = stmt.BindParam(params, paramTypes) + if err != nil { + panic(err) + } + err = stmt.AddBatch() + if err != nil { + panic(err) + } + err = stmt.Exec() + if err != nil { + panic(err) + } + affected := stmt.GetAffectedRows() + fmt.Println("all_json affected rows:", affected) + err = stmt.Close() + if err != nil { + panic(err) + } + } + { + stmt, err := connector.Init() + if err != nil { + panic(err) + } + err = stmt.Prepare("insert into ? using all_all tags(?,?,?,?,?,?,?,?,?,?,?,?,?,?) values(?,?,?,?,?,?,?,?,?,?,?,?,?,?)") + err = stmt.SetTableName("tb1") + if err != nil { + panic(err) + } + + err = stmt.SetTableName("tb2") + if err != nil { + panic(err) + } + err = stmt.SetTags( + param.NewParam(14). + AddTimestamp(now, 0). + AddBool(true). + AddTinyint(2). + AddSmallint(2). + AddInt(2). + AddBigint(2). + AddUTinyint(2). + AddUSmallint(2). + AddUInt(2). + AddUBigint(2). + AddFloat(2). + AddDouble(2). + AddBinary([]byte("tb2")). + AddNchar("tb2"), + param.NewColumnType(14). + AddTimestamp(). + AddBool(). + AddTinyint(). + AddSmallint(). + AddInt(). + AddBigint(). + AddUTinyint(). + AddUSmallint(). + AddUInt(). + AddUBigint(). + AddFloat(). + AddDouble(). + AddBinary(0). + AddNchar(0), + ) + if err != nil { + panic(err) + } + params := []*param.Param{ + param.NewParam(3).AddTimestamp(now, 0).AddTimestamp(now.Add(time.Second), 0).AddTimestamp(now.Add(time.Second*2), 0), + param.NewParam(3).AddBool(true).AddNull().AddBool(true), + param.NewParam(3).AddTinyint(1).AddNull().AddTinyint(1), + param.NewParam(3).AddSmallint(1).AddNull().AddSmallint(1), + param.NewParam(3).AddInt(1).AddNull().AddInt(1), + param.NewParam(3).AddBigint(1).AddNull().AddBigint(1), + param.NewParam(3).AddUTinyint(1).AddNull().AddUTinyint(1), + param.NewParam(3).AddUSmallint(1).AddNull().AddUSmallint(1), + param.NewParam(3).AddUInt(1).AddNull().AddUInt(1), + param.NewParam(3).AddUBigint(1).AddNull().AddUBigint(1), + param.NewParam(3).AddFloat(1).AddNull().AddFloat(1), + param.NewParam(3).AddDouble(1).AddNull().AddDouble(1), + param.NewParam(3).AddBinary([]byte("test_binary")).AddNull().AddBinary([]byte("test_binary")), + param.NewParam(3).AddNchar("test_nchar").AddNull().AddNchar("test_nchar"), + } + paramTypes := param.NewColumnType(14). + AddTimestamp(). + AddBool(). + AddTinyint(). + AddSmallint(). + AddInt(). + AddBigint(). + AddUTinyint(). + AddUSmallint(). + AddUInt(). + AddUBigint(). + AddFloat(). + AddDouble(). + AddBinary(0). + AddNchar(0) + err = stmt.BindParam(params, paramTypes) + if err != nil { + panic(err) + } + err = stmt.AddBatch() + if err != nil { + panic(err) + } + err = stmt.Exec() + if err != nil { + panic(err) + } + affected := stmt.GetAffectedRows() + fmt.Println("all_all affected rows:", affected) + err = stmt.Close() + if err != nil { + panic(err) + } + + } +} + +func prepareEnv(db *sql.DB) { + steps := []string{ + "create database example_ws_stmt", + "create table example_ws_stmt.all_json(ts timestamp," + + "c1 bool," + + "c2 tinyint," + + "c3 smallint," + + "c4 int," + + "c5 bigint," + + "c6 tinyint unsigned," + + "c7 smallint unsigned," + + "c8 int unsigned," + + "c9 bigint unsigned," + + "c10 float," + + "c11 double," + + "c12 binary(20)," + + "c13 nchar(20)" + + ")" + + "tags(t json)", + "create table example_ws_stmt.all_all(" + + "ts timestamp," + + "c1 bool," + + "c2 tinyint," + + "c3 smallint," + + "c4 int," + + "c5 bigint," + + "c6 tinyint unsigned," + + "c7 smallint unsigned," + + "c8 int unsigned," + + "c9 bigint unsigned," + + "c10 float," + + "c11 double," + + "c12 binary(20)," + + "c13 nchar(20)" + + ")" + + "tags(" + + "tts timestamp," + + "tc1 bool," + + "tc2 tinyint," + + "tc3 smallint," + + "tc4 int," + + "tc5 bigint," + + "tc6 tinyint unsigned," + + "tc7 smallint unsigned," + + "tc8 int unsigned," + + "tc9 bigint unsigned," + + "tc10 float," + + "tc11 double," + + "tc12 binary(20)," + + "tc13 nchar(20))", + } + for _, step := range steps { + _, err := db.Exec(step) + if err != nil { + panic(err) + } + } +} diff --git a/examples/taosWS/main.go b/examples/taosWS/main.go new file mode 100644 index 0000000..cc6a598 --- /dev/null +++ b/examples/taosWS/main.go @@ -0,0 +1,103 @@ +package main + +import ( + "database/sql" + "fmt" + "time" + + _ "github.com/taosdata/driver-go/v3/taosWS" +) + +func main() { + db, err := sql.Open("taosWS", "root:taosdata@ws(127.0.0.1:6041)/") + if err != nil { + panic(err) + } + defer db.Close() + _, err = db.Exec("create database if not exists example_taos_ws") + if err != nil { + panic(err) + } + _, err = db.Exec("create table if not exists example_taos_ws.stb(ts timestamp," + + "c1 bool," + + "c2 tinyint," + + "c3 smallint," + + "c4 int," + + "c5 bigint," + + "c6 tinyint unsigned," + + "c7 smallint unsigned," + + "c8 int unsigned," + + "c9 bigint unsigned," + + "c10 float," + + "c11 double," + + "c12 binary(20)," + + "c13 nchar(20)" + + ") tags (info json)") + if err != nil { + panic(err) + } + _, err = db.Exec("create table if not exists example_taos_ws.tb1 using example_taos_ws.stb tags ('{\"name\":\"tb1\"}')") + if err != nil { + panic(err) + } + now := time.Now() + _, err = db.Exec(fmt.Sprintf("insert into example_taos_ws.tb1 values ('%s',true,2,3,4,5,6,7,8,9,10,11,'binary','nchar')", now.Format(time.RFC3339Nano))) + if err != nil { + panic(err) + } + rows, err := db.Query(fmt.Sprintf("select * from example_taos_ws.tb1 where ts = '%s'", now.Format(time.RFC3339Nano))) + if err != nil { + panic(err) + } + for rows.Next() { + var ( + ts time.Time + c1 bool + c2 int8 + c3 int16 + c4 int32 + c5 int64 + c6 uint8 + c7 uint16 + c8 uint32 + c9 uint64 + c10 float32 + c11 float64 + c12 string + c13 string + ) + err = rows.Scan( + &ts, + &c1, + &c2, + &c3, + &c4, + &c5, + &c6, + &c7, + &c8, + &c9, + &c10, + &c11, + &c12, + &c13, + ) + if err != nil { + panic(err) + } + fmt.Println("ts:", ts.Local()) + fmt.Println("c1:", c1) + fmt.Println("c2:", c2) + fmt.Println("c3:", c3) + fmt.Println("c4:", c4) + fmt.Println("c5:", c5) + fmt.Println("c6:", c6) + fmt.Println("c7:", c7) + fmt.Println("c8:", c8) + fmt.Println("c9:", c9) + fmt.Println("c10:", c10) + fmt.Println("c11:", c11) + fmt.Println("c12:", c12) + fmt.Println("c13:", c13) + } +} diff --git a/examples/tmq/main.go b/examples/tmq/main.go index 7721eed..bf52a0e 100644 --- a/examples/tmq/main.go +++ b/examples/tmq/main.go @@ -1,17 +1,12 @@ package main import ( - "context" - "encoding/json" "fmt" - "strconv" - "time" + "os" "github.com/taosdata/driver-go/v3/af" "github.com/taosdata/driver-go/v3/af/tmq" - "github.com/taosdata/driver-go/v3/common" - "github.com/taosdata/driver-go/v3/errors" - "github.com/taosdata/driver-go/v3/wrapper" + tmqcommon "github.com/taosdata/driver-go/v3/common/tmq" ) func main() { @@ -20,101 +15,59 @@ func main() { panic(err) } defer db.Close() - _, err = db.Exec("create database if not exists example_tmq") + _, err = db.Exec("create database if not exists example_tmq WAL_RETENTION_PERIOD 86400") if err != nil { panic(err) } - _, err = db.Exec("create topic if not exists example_tmq_topic with meta as DATABASE example_tmq") + _, err = db.Exec("create topic if not exists example_tmq_topic as DATABASE example_tmq") if err != nil { panic(err) } - config := tmq.NewConfig() - defer config.Destroy() - err = config.SetGroupID("test") if err != nil { panic(err) } - err = config.SetAutoOffsetReset("earliest") - if err != nil { - panic(err) - } - err = config.SetConnectIP("127.0.0.1") - if err != nil { - panic(err) - } - err = config.SetConnectUser("root") - if err != nil { - panic(err) - } - err = config.SetConnectPass("taosdata") - if err != nil { - panic(err) - } - err = config.SetConnectPort("6030") - if err != nil { - panic(err) - } - err = config.SetMsgWithTableName(true) - if err != nil { - panic(err) - } - err = config.EnableHeartBeat() - if err != nil { - panic(err) - } - err = config.EnableAutoCommit(func(result *wrapper.TMQCommitCallbackResult) { - if result.ErrCode != 0 { - errStr := wrapper.TMQErr2Str(result.ErrCode) - err := errors.NewError(int(result.ErrCode), errStr) - panic(err) - } + consumer, err := tmq.NewConsumer(&tmqcommon.ConfigMap{ + "group.id": "test", + "auto.offset.reset": "earliest", + "td.connect.ip": "127.0.0.1", + "td.connect.user": "root", + "td.connect.pass": "taosdata", + "td.connect.port": "6030", + "client.id": "test_tmq_client", + "enable.auto.commit": "false", + "enable.heartbeat.background": "true", + "experimental.snapshot.enable": "true", + "msg.with.table.name": "true", }) if err != nil { panic(err) } - consumer, err := tmq.NewConsumer(config) + err = consumer.Subscribe("example_tmq_topic", nil) if err != nil { panic(err) } - err = consumer.Subscribe([]string{"example_tmq_topic"}) + _, err = db.Exec("create table example_tmq.t1 (ts timestamp,v int)") if err != nil { panic(err) } - _, err = db.Exec("create table example_tmq.t1 (ts timestamp,v int)") + _, err = db.Exec("insert into example_tmq.t1 values(now,1)") if err != nil { panic(err) } - for { - result, err := consumer.Poll(time.Second) - if err != nil { - panic(err) + for i := 0; i < 5; i++ { + ev := consumer.Poll(0) + if ev != nil { + switch e := ev.(type) { + case *tmqcommon.DataMessage: + fmt.Println(e.Value()) + case tmqcommon.Error: + fmt.Fprintf(os.Stderr, "%% Error: %v: %v\n", e.Code(), e) + panic(e) + } } - if result.Type != common.TMQ_RES_TABLE_META { - panic("want message type 2 got " + strconv.Itoa(int(result.Type))) - } - data, _ := json.Marshal(result.Meta) - fmt.Println(string(data)) - consumer.Commit(context.Background(), result.Message) - consumer.FreeMessage(result.Message) - break } - _, err = db.Exec("insert into example_tmq.t1 values(now,1)") + err = consumer.Close() if err != nil { panic(err) } - for { - result, err := consumer.Poll(time.Second) - if err != nil { - panic(err) - } - if result.Type != common.TMQ_RES_DATA { - panic("want message type 1 got " + strconv.Itoa(int(result.Type))) - } - data, _ := json.Marshal(result.Data) - fmt.Println(string(data)) - consumer.Commit(context.Background(), result.Message) - consumer.FreeMessage(result.Message) - break - } - consumer.Close() } diff --git a/examples/tmqoverws/main.go b/examples/tmqoverws/main.go new file mode 100644 index 0000000..9d9eda3 --- /dev/null +++ b/examples/tmqoverws/main.go @@ -0,0 +1,89 @@ +package main + +import ( + "database/sql" + "fmt" + + "github.com/taosdata/driver-go/v3/common" + tmqcommon "github.com/taosdata/driver-go/v3/common/tmq" + _ "github.com/taosdata/driver-go/v3/taosRestful" + "github.com/taosdata/driver-go/v3/ws/tmq" +) + +func main() { + db, err := sql.Open("taosRestful", "root:taosdata@http(localhost:6041)/") + if err != nil { + panic(err) + } + defer db.Close() + prepareEnv(db) + consumer, err := tmq.NewConsumer(&tmqcommon.ConfigMap{ + "ws.url": "ws://127.0.0.1:6041/rest/tmq", + "ws.message.channelLen": uint(0), + "ws.message.timeout": common.DefaultMessageTimeout, + "ws.message.writeWait": common.DefaultWriteWait, + "td.connect.user": "root", + "td.connect.pass": "taosdata", + "group.id": "example", + "client.id": "example_consumer", + "auto.offset.reset": "earliest", + }) + if err != nil { + panic(err) + } + err = consumer.Subscribe("example_ws_tmq_topic", nil) + if err != nil { + panic(err) + } + go func() { + _, err := db.Exec("create table example_ws_tmq.t_all(ts timestamp," + + "c1 bool," + + "c2 tinyint," + + "c3 smallint," + + "c4 int," + + "c5 bigint," + + "c6 tinyint unsigned," + + "c7 smallint unsigned," + + "c8 int unsigned," + + "c9 bigint unsigned," + + "c10 float," + + "c11 double," + + "c12 binary(20)," + + "c13 nchar(20)" + + ")") + if err != nil { + panic(err) + } + _, err = db.Exec("insert into example_ws_tmq.t_all values(now,true,2,3,4,5,6,7,8,9,10.123,11.123,'binary','nchar')") + if err != nil { + panic(err) + } + }() + for i := 0; i < 5; i++ { + ev := consumer.Poll(0) + if ev != nil { + switch e := ev.(type) { + case *tmqcommon.DataMessage: + fmt.Printf("get message:%v", e) + case tmqcommon.Error: + fmt.Printf("%% Error: %v: %v\n", e.Code(), e) + panic(e) + } + } + } + err = consumer.Close() + if err != nil { + panic(err) + } +} + +func prepareEnv(db *sql.DB) { + _, err := db.Exec("create database example_ws_tmq WAL_RETENTION_PERIOD 86400") + if err != nil { + panic(err) + } + _, err = db.Exec("create topic example_ws_tmq_topic as database example_ws_tmq") + if err != nil { + panic(err) + } +} diff --git a/go.mod b/go.mod index 01c06ed..a416a3e 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,8 @@ module github.com/taosdata/driver-go/v3 go 1.14 require ( + github.com/google/uuid v1.3.0 + github.com/gorilla/websocket v1.5.0 github.com/json-iterator/go v1.1.12 github.com/spf13/cast v1.5.0 github.com/stretchr/testify v1.8.0 diff --git a/taosRestful/connection.go b/taosRestful/connection.go index b51d0c3..ac2414b 100644 --- a/taosRestful/connection.go +++ b/taosRestful/connection.go @@ -20,9 +20,11 @@ import ( "net/url" "strings" "time" + + jsoniter "github.com/json-iterator/go" ) -var jsonI = jsonitor.ConfigCompatibleWithStandardLibrary +var jsonI = jsoniter.ConfigCompatibleWithStandardLibrary const defaultSlowThreshold = time.Millisecond * 500 @@ -101,6 +103,14 @@ func (tc *taosConn) Prepare(query string) (driver.Stmt, error) { } func (tc *taosConn) Exec(query string, args []driver.Value) (driver.Result, error) { + return tc.ExecContext(context.Background(), query, common.ValueArgsToNamedValueArgs(args)) +} + +func (tc *taosConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (result driver.Result, err error) { + return tc.execCtx(ctx, query, args) +} + +func (tc *taosConn) execCtx(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { if len(args) != 0 { if !tc.cfg.interpolateParams { return nil, driver.ErrSkip @@ -112,7 +122,7 @@ func (tc *taosConn) Exec(query string, args []driver.Value) (driver.Result, erro } query = prepared } - result, err := tc.taosQuery(context.TODO(), query, 512) + result, err := tc.taosQuery(ctx, query, 512) if err != nil { return nil, err } @@ -128,7 +138,7 @@ func (tc *taosConn) Query(query string, args []driver.Value) (driver.Rows, error return nil, driver.ErrSkip } // try client-side prepare to reduce round trip - prepared, err := common.InterpolateParams(query, args) + prepared, err := common.InterpolateParams(query, common.ValueArgsToNamedValueArgs(args)) if err != nil { return nil, err } @@ -148,6 +158,36 @@ func (tc *taosConn) Query(query string, args []driver.Value) (driver.Rows, error return rs, err } +func (tc *taosConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (rows driver.Rows, err error) { + return tc.queryCtx(ctx, query, args) +} + +func (tc *taosConn) queryCtx(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + if len(args) != 0 { + if !tc.cfg.interpolateParams { + return nil, driver.ErrSkip + } + // try client-side prepare to reduce round trip + prepared, err := common.InterpolateParams(query, args) + if err != nil { + return nil, err + } + query = prepared + } + result, err := tc.taosQuery(ctx, query, tc.readBufferSize) + if err != nil { + return nil, err + } + if result == nil { + return nil, errors.New("wrong result") + } + // Read Result + rs := &rows{ + result: result, + } + return rs, err +} + func (tc *taosConn) Ping(ctx context.Context) (err error) { return nil } @@ -207,24 +247,22 @@ func (tc *taosConn) taosQuery(ctx context.Context, sql string, bufferSize int) ( return data, nil } -const HTTPDTimeFormat = "2006-01-02T15:04:05.999999999-0700" - func marshalBody(body io.Reader, bufferSize int) (*common.TDEngineRestfulResp, error) { var result common.TDEngineRestfulResp iter := jsonI.BorrowIterator(make([]byte, bufferSize)) defer jsonI.ReturnIterator(iter) iter.Reset(body) timeFormat := time.RFC3339Nano - iter.ReadObjectCB(func(iter *jsonitor.Iterator, s string) bool { + iter.ReadObjectCB(func(iter *jsoniter.Iterator, s string) bool { switch s { case "code": result.Code = iter.ReadInt() case "desc": result.Desc = iter.ReadString() case "column_meta": - iter.ReadArrayCB(func(iter *jsonitor.Iterator) bool { + iter.ReadArrayCB(func(iter *jsoniter.Iterator) bool { index := 0 - iter.ReadArrayCB(func(iter *jsonitor.Iterator) bool { + iter.ReadArrayCB(func(iter *jsoniter.Iterator) bool { switch index { case 0: result.ColNames = append(result.ColNames, iter.ReadString()) @@ -249,10 +287,10 @@ func marshalBody(body io.Reader, bufferSize int) (*common.TDEngineRestfulResp, e case "data": columnCount := len(result.ColTypes) column := 0 - iter.ReadArrayCB(func(iter *jsonitor.Iterator) bool { + iter.ReadArrayCB(func(iter *jsoniter.Iterator) bool { column = 0 var row = make([]driver.Value, columnCount) - iter.ReadArrayCB(func(iter *jsonitor.Iterator) bool { + iter.ReadArrayCB(func(iter *jsoniter.Iterator) bool { defer func() { column += 1 }() diff --git a/taosRestful/connector.go b/taosRestful/connector.go index 03477dd..3ae7396 100644 --- a/taosRestful/connector.go +++ b/taosRestful/connector.go @@ -35,7 +35,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { } // Driver implements driver.Connector interface. -// Driver returns &tdengineDriver{}. +// Driver returns &TDengineDriver{}. func (c *connector) Driver() driver.Driver { - return &tdengineDriver{} + return &TDengineDriver{} } diff --git a/taosRestful/driver.go b/taosRestful/driver.go index b7a223e..54313f3 100644 --- a/taosRestful/driver.go +++ b/taosRestful/driver.go @@ -6,13 +6,13 @@ import ( "database/sql/driver" ) -// tdengineDriver is exported to make the driver directly accessible. +// TDengineDriver is exported to make the driver directly accessible. // In general the driver is used via the database/sql package. -type tdengineDriver struct{} +type TDengineDriver struct{} // Open new Connection. // the DSN string is formatted -func (d tdengineDriver) Open(dsn string) (driver.Conn, error) { +func (d TDengineDriver) Open(dsn string) (driver.Conn, error) { cfg, err := parseDSN(dsn) if err != nil { return nil, err @@ -24,5 +24,5 @@ func (d tdengineDriver) Open(dsn string) (driver.Conn, error) { } func init() { - sql.Register("taosRestful", &tdengineDriver{}) + sql.Register("taosRestful", &TDengineDriver{}) } diff --git a/taosRestful/driver_test.go b/taosRestful/driver_test.go index f5e7cfc..77d6afa 100644 --- a/taosRestful/driver_test.go +++ b/taosRestful/driver_test.go @@ -84,9 +84,8 @@ func (dbt *DBTest) InsertInto(numOfSubTab, numOfItems int) { } type TestResult struct { - ts string - value bool - degress int + ts string + value bool } func runTests(t *testing.T, tests ...func(dbt *DBTest)) { diff --git a/taosSql/connection.go b/taosSql/connection.go index 0fc22c0..086ee09 100644 --- a/taosSql/connection.go +++ b/taosSql/connection.go @@ -74,9 +74,24 @@ func (tc *taosConn) Prepare(query string) (driver.Stmt, error) { } func (tc *taosConn) Exec(query string, args []driver.Value) (driver.Result, error) { + return tc.ExecContext(context.Background(), query, common.ValueArgsToNamedValueArgs(args)) +} + +func (tc *taosConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (rows driver.Result, err error) { if tc.taos == nil { return nil, driver.ErrBadConn } + + return tc.execCtx(ctx, query, args) +} + +func (tc *taosConn) execCtx(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + var reqIDValue int64 + reqID := ctx.Value(common.ReqIDKey) + if reqID != nil { + reqIDValue, _ = reqID.(int64) + } + if len(args) != 0 { if !tc.cfg.interpolateParams { return nil, driver.ErrSkip @@ -88,9 +103,13 @@ func (tc *taosConn) Exec(query string, args []driver.Value) (driver.Result, erro } query = prepared } - handler := asyncHandlerPool.Get() - defer asyncHandlerPool.Put(handler) - result := tc.taosQuery(query, handler) + h := asyncHandlerPool.Get() + defer asyncHandlerPool.Put(h) + result := tc.taosQuery(query, h, reqIDValue) + return tc.processExecResult(result) +} + +func (tc *taosConn) processExecResult(result *handler.AsyncResult) (driver.Result, error) { defer func() { if result != nil && result.Res != nil { locker.Lock() @@ -106,13 +125,25 @@ func (tc *taosConn) Exec(query string, args []driver.Value) (driver.Result, erro } affectRows := wrapper.TaosAffectedRows(res) return driver.RowsAffected(affectRows), nil - } func (tc *taosConn) Query(query string, args []driver.Value) (driver.Rows, error) { + return tc.QueryContext(context.Background(), query, common.ValueArgsToNamedValueArgs(args)) +} + +func (tc *taosConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (rows driver.Rows, err error) { if tc.taos == nil { return nil, driver.ErrBadConn } + return tc.queryCtx(ctx, query, args) +} + +func (tc *taosConn) queryCtx(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + var reqIDValue int64 + reqID := ctx.Value(common.ReqIDKey) + if reqID != nil { + reqIDValue, _ = reqID.(int64) + } if len(args) != 0 { if !tc.cfg.interpolateParams { return nil, driver.ErrSkip @@ -124,12 +155,16 @@ func (tc *taosConn) Query(query string, args []driver.Value) (driver.Rows, error } query = prepared } - handler := asyncHandlerPool.Get() - result := tc.taosQuery(query, handler) + h := asyncHandlerPool.Get() + result := tc.taosQuery(query, h, reqIDValue) + return tc.processRows(result, h) +} + +func (tc *taosConn) processRows(result *handler.AsyncResult, h *handler.Handler) (driver.Rows, error) { res := result.Res code := wrapper.TaosError(res) if code != int(errors.SUCCESS) { - asyncHandlerPool.Put(handler) + asyncHandlerPool.Put(h) errStr := wrapper.TaosErrorStr(res) locker.Lock() wrapper.TaosFreeResult(result.Res) @@ -139,11 +174,12 @@ func (tc *taosConn) Query(query string, args []driver.Value) (driver.Rows, error numFields := wrapper.TaosNumFields(res) rowsHeader, err := wrapper.ReadColumn(res, numFields) if err != nil { + asyncHandlerPool.Put(h) return nil, err } precision := wrapper.TaosResultPrecision(res) rs := &rows{ - handler: handler, + handler: h, rowsHeader: rowsHeader, result: res, precision: precision, @@ -164,9 +200,13 @@ func (tc *taosConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver. return nil, &errors.TaosError{Code: 0xffff, ErrStr: "taosSql does not support transaction"} } -func (tc *taosConn) taosQuery(sqlStr string, handler *handler.Handler) *handler.AsyncResult { +func (tc *taosConn) taosQuery(sqlStr string, handler *handler.Handler, reqID int64) *handler.AsyncResult { locker.Lock() - wrapper.TaosQueryA(tc.taos, sqlStr, handler.Handler) + if reqID == 0 { + wrapper.TaosQueryA(tc.taos, sqlStr, handler.Handler) + } else { + wrapper.TaosQueryAWithReqID(tc.taos, sqlStr, handler.Handler, reqID) + } locker.Unlock() r := <-handler.Caller.QueryResult return r diff --git a/taosSql/connection_test.go b/taosSql/connection_test.go new file mode 100644 index 0000000..1d56107 --- /dev/null +++ b/taosSql/connection_test.go @@ -0,0 +1,50 @@ +package taosSql + +import ( + "context" + "database/sql" + "testing" + + "github.com/taosdata/driver-go/v3/common" +) + +func TestTaosConn_ExecContext(t *testing.T) { + ctx := context.WithValue(context.Background(), common.ReqIDKey, common.GetReqID()) + db, err := sql.Open("taosSql", dataSourceName) + if err != nil { + t.Fatal(err) + } + defer db.Close() + defer func() { + _, err = db.ExecContext(ctx, "drop database if exists test_connection") + }() + _, err = db.ExecContext(ctx, "create database if not exists test_connection") + if err != nil { + t.Fatal(err) + } + _, err = db.ExecContext(ctx, "use test_connection") + if err != nil { + t.Fatal(err) + } + _, err = db.ExecContext(ctx, "create stable if not exists meters (ts timestamp, current float, voltage int, phase float) tags (location binary(64), groupId int)") + if err != nil { + t.Fatal(err) + } + _, err = db.ExecContext(ctx, "INSERT INTO d21001 USING meters TAGS ('California.SanFrancisco', 2) VALUES ('?', ?, ?, ?)", "2021-07-13 14:06:32.272", 10.2, 219, 0.32) + if err != nil { + t.Fatal(err) + } + rs, err := db.QueryContext(ctx, "select count(*) from meters") + if err != nil { + t.Fatal(err) + } + defer rs.Close() + rs.Next() + var count int64 + if err = rs.Scan(&count); err != nil { + t.Fatal(err) + } + if count != 1 { + t.Fatal("result miss") + } +} diff --git a/taosSql/connector.go b/taosSql/connector.go index 3930359..215b21e 100644 --- a/taosSql/connector.go +++ b/taosSql/connector.go @@ -60,7 +60,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { } // Driver implements driver.Connector interface. -// Driver returns &tdengineDriver{}. +// Driver returns &TDengineDriver{}. func (c *connector) Driver() driver.Driver { - return &tdengineDriver{} + return &TDengineDriver{} } diff --git a/taosSql/driver.go b/taosSql/driver.go index 0367fc1..f50a907 100644 --- a/taosSql/driver.go +++ b/taosSql/driver.go @@ -16,13 +16,13 @@ var onceInitLock = sync.Once{} var asyncHandlerPool *handler.HandlerPool var onceInitHandlerPool = sync.Once{} -// tdengineDriver is exported to make the driver directly accessible. +// TDengineDriver is exported to make the driver directly accessible. // In general the driver is used via the database/sql package. -type tdengineDriver struct{} +type TDengineDriver struct{} // Open new Connection. // the DSN string is formatted -func (d tdengineDriver) Open(dsn string) (driver.Conn, error) { +func (d TDengineDriver) Open(dsn string) (driver.Conn, error) { cfg, err := parseDSN(dsn) if err != nil { return nil, err @@ -48,5 +48,5 @@ func (d tdengineDriver) Open(dsn string) (driver.Conn, error) { } func init() { - sql.Register("taosSql", &tdengineDriver{}) + sql.Register("taosSql", &TDengineDriver{}) } diff --git a/taosSql/driver_test.go b/taosSql/driver_test.go index 8886c39..9dfc33c 100644 --- a/taosSql/driver_test.go +++ b/taosSql/driver_test.go @@ -69,9 +69,8 @@ func (dbt *DBTest) InsertInto(numOfSubTab, numOfItems int) { } type TestResult struct { - ts string - value bool - degress int + ts string + value bool } func runTests(t *testing.T, tests ...func(dbt *DBTest)) { @@ -152,6 +151,7 @@ var ( if eErr == userErr && err != nil { return ret } + defer rows.Close() if err != nil { dbt.Errorf("%s is not expected, err: %s", query, err.Error()) return ret @@ -165,7 +165,6 @@ var ( } count = count + 1 } - rows.Close() ret = count if expected != -1 && count != expected { dbt.Errorf("%s is not expected, err: %s", query, errors.New("result is not expected")) @@ -206,7 +205,7 @@ func TestAny(t *testing.T) { tests = append(tests, &Obj{fmt.Sprintf("select first(*) from %s.t%d", dbName, 0), nil, false, fp, int64(1)}) tests = append(tests, - &Obj{fmt.Sprintf("select error"), userErr, false, fp, int64(1)}) + &Obj{"select error", userErr, false, fp, int64(1)}) tests = append(tests, &Obj{fmt.Sprintf("select * from %s.t%d", dbName, 0), nil, false, fp, int64(-1)}) tests = append(tests, @@ -292,10 +291,10 @@ func TestStmt(t *testing.T) { if err != nil { dbt.fail("prepare", "prepare", err) } + defer stmt.Close() now := time.Now() stmt.Exec(now.UnixNano()/int64(time.Millisecond), false) stmt.Exec(now.UnixNano()/int64(time.Millisecond)+int64(1), false) - stmt.Close() }) } diff --git a/taosSql/rows.go b/taosSql/rows.go index 94a02b0..685f66f 100644 --- a/taosSql/rows.go +++ b/taosSql/rows.go @@ -6,6 +6,7 @@ import ( "reflect" "unsafe" + "github.com/taosdata/driver-go/v3/common/parser" "github.com/taosdata/driver-go/v3/errors" "github.com/taosdata/driver-go/v3/wrapper" "github.com/taosdata/driver-go/v3/wrapper/handler" @@ -75,7 +76,7 @@ func (rs *rows) Next(dest []driver.Value) error { rs.block = nil return io.EOF } - wrapper.ReadRow(dest, rs.block, rs.blockSize, rs.blockOffset, rs.rowsHeader.ColTypes, rs.precision) + parser.ReadRow(dest, rs.block, rs.blockSize, rs.blockOffset, rs.rowsHeader.ColTypes, rs.precision) rs.blockOffset++ return nil } @@ -108,7 +109,10 @@ func (rs *rows) asyncFetchRows() *handler.AsyncResult { } func (rs *rows) freeResult() { - asyncHandlerPool.Put(rs.handler) + if rs.handler != nil { + asyncHandlerPool.Put(rs.handler) + rs.handler = nil + } if rs.result != nil { locker.Lock() wrapper.TaosFreeResult(rs.result) diff --git a/taosSql/statement.go b/taosSql/statement.go index 63336cf..f9fea69 100644 --- a/taosSql/statement.go +++ b/taosSql/statement.go @@ -24,7 +24,7 @@ type Stmt struct { pSql string isInsert bool cols []*wrapper.StmtField - tags []*wrapper.StmtField + //tags []*wrapper.StmtField } func (stmt *Stmt) Close() error { @@ -121,6 +121,7 @@ func (stmt *Stmt) CheckNamedValue(v *driver.NamedValue) error { errStr := wrapper.TaosStmtErrStr(stmt.stmt) return errors.NewError(code, errStr) } + defer wrapper.TaosStmtReclaimFields(stmt.stmt, fieldsP) stmt.cols = wrapper.StmtParseFields(num, fieldsP) } if v.Ordinal > len(stmt.cols) { @@ -156,7 +157,7 @@ func (stmt *Stmt) CheckNamedValue(v *driver.NamedValue) error { rv := reflect.ValueOf(v.Value) switch rv.Kind() { case reflect.Bool: - if rv.Bool() == true { + if rv.Bool() { v.Value = types.TaosTinyint(1) } else { v.Value = types.TaosTinyint(0) @@ -180,7 +181,7 @@ func (stmt *Stmt) CheckNamedValue(v *driver.NamedValue) error { rv := reflect.ValueOf(v.Value) switch rv.Kind() { case reflect.Bool: - if rv.Bool() == true { + if rv.Bool() { v.Value = types.TaosSmallint(1) } else { v.Value = types.TaosSmallint(0) @@ -204,7 +205,7 @@ func (stmt *Stmt) CheckNamedValue(v *driver.NamedValue) error { rv := reflect.ValueOf(v.Value) switch rv.Kind() { case reflect.Bool: - if rv.Bool() == true { + if rv.Bool() { v.Value = types.TaosInt(1) } else { v.Value = types.TaosInt(0) @@ -228,7 +229,7 @@ func (stmt *Stmt) CheckNamedValue(v *driver.NamedValue) error { rv := reflect.ValueOf(v.Value) switch rv.Kind() { case reflect.Bool: - if rv.Bool() == true { + if rv.Bool() { v.Value = types.TaosBigint(1) } else { v.Value = types.TaosBigint(0) @@ -252,7 +253,7 @@ func (stmt *Stmt) CheckNamedValue(v *driver.NamedValue) error { rv := reflect.ValueOf(v.Value) switch rv.Kind() { case reflect.Bool: - if rv.Bool() == true { + if rv.Bool() { v.Value = types.TaosFloat(1) } else { v.Value = types.TaosFloat(0) @@ -276,7 +277,7 @@ func (stmt *Stmt) CheckNamedValue(v *driver.NamedValue) error { rv := reflect.ValueOf(v.Value) switch rv.Kind() { case reflect.Bool: - if rv.Bool() == true { + if rv.Bool() { v.Value = types.TaosDouble(1) } else { v.Value = types.TaosDouble(0) @@ -359,7 +360,7 @@ func (stmt *Stmt) CheckNamedValue(v *driver.NamedValue) error { rv := reflect.ValueOf(v.Value) switch rv.Kind() { case reflect.Bool: - if rv.Bool() == true { + if rv.Bool() { v.Value = types.TaosUTinyint(1) } else { v.Value = types.TaosUTinyint(0) @@ -383,7 +384,7 @@ func (stmt *Stmt) CheckNamedValue(v *driver.NamedValue) error { rv := reflect.ValueOf(v.Value) switch rv.Kind() { case reflect.Bool: - if rv.Bool() == true { + if rv.Bool() { v.Value = types.TaosUSmallint(1) } else { v.Value = types.TaosUSmallint(0) @@ -407,7 +408,7 @@ func (stmt *Stmt) CheckNamedValue(v *driver.NamedValue) error { rv := reflect.ValueOf(v.Value) switch rv.Kind() { case reflect.Bool: - if rv.Bool() == true { + if rv.Bool() { v.Value = types.TaosUInt(1) } else { v.Value = types.TaosUInt(0) @@ -431,7 +432,7 @@ func (stmt *Stmt) CheckNamedValue(v *driver.NamedValue) error { rv := reflect.ValueOf(v.Value) switch rv.Kind() { case reflect.Bool: - if rv.Bool() == true { + if rv.Bool() { v.Value = types.TaosUBigint(1) } else { v.Value = types.TaosUBigint(0) diff --git a/taosSql/statement_test.go b/taosSql/statement_test.go index b9337c3..2442ff9 100644 --- a/taosSql/statement_test.go +++ b/taosSql/statement_test.go @@ -1047,6 +1047,10 @@ func TestStmtConvertExec(t *testing.T) { } var data []driver.Value tts, err := rows.ColumnTypes() + if err != nil { + t.Error(err) + return + } typesL := make([]reflect.Type, 1) for i, tp := range tts { st := tp.ScanType() diff --git a/taosWS/connection.go b/taosWS/connection.go new file mode 100644 index 0000000..71697e6 --- /dev/null +++ b/taosWS/connection.go @@ -0,0 +1,347 @@ +package taosWS + +import ( + "bytes" + "context" + "database/sql/driver" + "encoding/json" + "errors" + "fmt" + "net/url" + "strings" + "sync/atomic" + "time" + + "github.com/gorilla/websocket" + jsoniter "github.com/json-iterator/go" + "github.com/taosdata/driver-go/v3/common" + taosErrors "github.com/taosdata/driver-go/v3/errors" +) + +var jsonI = jsoniter.ConfigCompatibleWithStandardLibrary + +const ( + WSConnect = "conn" + WSQuery = "query" + WSFetch = "fetch" + WSFetchBlock = "fetch_block" + WSFreeResult = "free_result" +) + +var ( + NotQueryError = errors.New("sql is an update statement not a query statement") + ReadTimeoutError = errors.New("read timeout") +) + +type taosConn struct { + buf *bytes.Buffer + client *websocket.Conn + requestID uint64 + readTimeout time.Duration + writeTimeout time.Duration + cfg *config + endpoint string +} + +func (tc *taosConn) generateReqID() uint64 { + return atomic.AddUint64(&tc.requestID, 1) +} + +func newTaosConn(cfg *config) (*taosConn, error) { + endpointUrl := &url.URL{ + Scheme: cfg.net, + Host: fmt.Sprintf("%s:%d", cfg.addr, cfg.port), + Path: "/rest/ws", + } + if cfg.token != "" { + endpointUrl.RawQuery = fmt.Sprintf("token=%s", cfg.token) + } + endpoint := endpointUrl.String() + ws, _, err := common.DefaultDialer.Dial(endpoint, nil) + if err != nil { + return nil, err + } + ws.SetReadLimit(common.BufferSize4M) + ws.SetReadDeadline(time.Now().Add(common.DefaultPongWait)) + ws.SetPongHandler(func(string) error { + ws.SetReadDeadline(time.Now().Add(common.DefaultPongWait)) + return nil + }) + tc := &taosConn{ + buf: &bytes.Buffer{}, + client: ws, + requestID: 0, + readTimeout: cfg.readTimeout, + writeTimeout: cfg.writeTimeout, + cfg: cfg, + endpoint: endpoint, + } + + err = tc.connect() + if err != nil { + tc.Close() + } + return tc, nil +} + +func (tc *taosConn) Begin() (driver.Tx, error) { + return nil, &taosErrors.TaosError{Code: 0xffff, ErrStr: "websocket does not support transaction"} +} + +func (tc *taosConn) Close() (err error) { + if tc.client != nil { + err = tc.client.Close() + } + tc.client = nil + tc.cfg = nil + tc.endpoint = "" + return err +} + +func (tc *taosConn) Prepare(query string) (driver.Stmt, error) { + return nil, &taosErrors.TaosError{Code: 0xffff, ErrStr: "websocket does not support stmt"} +} + +func (tc *taosConn) Exec(query string, args []driver.Value) (driver.Result, error) { + return tc.execCtx(context.Background(), query, common.ValueArgsToNamedValueArgs(args)) +} + +func (tc *taosConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (result driver.Result, err error) { + return tc.execCtx(ctx, query, args) +} + +func (tc *taosConn) execCtx(_ context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + if len(args) != 0 { + if !tc.cfg.interpolateParams { + return nil, driver.ErrSkip + } + // try to interpolate the parameters to save extra round trips for preparing and closing a statement + prepared, err := common.InterpolateParams(query, args) + if err != nil { + return nil, err + } + query = prepared + } + reqID := tc.generateReqID() + req := &WSQueryReq{ + ReqID: reqID, + SQL: query, + } + reqArgs, err := json.Marshal(req) + if err != nil { + return nil, err + } + action := &WSAction{ + Action: WSQuery, + Args: reqArgs, + } + tc.buf.Reset() + err = jsonI.NewEncoder(tc.buf).Encode(action) + if err != nil { + return nil, err + } + err = tc.writeText(tc.buf.Bytes()) + if err != nil { + return nil, err + } + var resp WSQueryResp + err = tc.readTo(&resp) + if err != nil { + return nil, err + } + if resp.Code != 0 { + return nil, taosErrors.NewError(resp.Code, resp.Message) + } + return driver.RowsAffected(resp.AffectedRows), nil +} + +func (tc *taosConn) Query(query string, args []driver.Value) (driver.Rows, error) { + return tc.QueryContext(context.Background(), query, common.ValueArgsToNamedValueArgs(args)) +} + +func (tc *taosConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (rows driver.Rows, err error) { + return tc.queryCtx(ctx, query, args) +} + +func (tc *taosConn) queryCtx(_ context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + if len(args) != 0 { + if !tc.cfg.interpolateParams { + return nil, driver.ErrSkip + } + // try client-side prepare to reduce round trip + prepared, err := common.InterpolateParams(query, args) + if err != nil { + return nil, err + } + query = prepared + } + reqID := tc.generateReqID() + req := &WSQueryReq{ + ReqID: reqID, + SQL: query, + } + reqArgs, err := json.Marshal(req) + if err != nil { + return nil, err + } + action := &WSAction{ + Action: WSQuery, + Args: reqArgs, + } + tc.buf.Reset() + err = jsonI.NewEncoder(tc.buf).Encode(action) + if err != nil { + return nil, err + } + err = tc.writeText(tc.buf.Bytes()) + if err != nil { + return nil, err + } + var resp WSQueryResp + err = tc.readTo(&resp) + if err != nil { + return nil, err + } + if resp.Code != 0 { + return nil, taosErrors.NewError(resp.Code, resp.Message) + } + if resp.IsUpdate { + return nil, NotQueryError + } + rs := &rows{ + buf: &bytes.Buffer{}, + conn: tc, + resultID: resp.ID, + fieldsCount: resp.FieldsCount, + fieldsNames: resp.FieldsNames, + fieldsTypes: resp.FieldsTypes, + fieldsLengths: resp.FieldsLengths, + precision: resp.Precision, + } + return rs, err +} + +func (tc *taosConn) Ping(ctx context.Context) (err error) { + return nil +} + +func (tc *taosConn) connect() error { + req := &WSConnectReq{ + ReqID: 0, + User: tc.cfg.user, + Password: tc.cfg.passwd, + DB: tc.cfg.dbName, + } + args, err := jsonI.Marshal(req) + if err != nil { + return err + } + action := &WSAction{ + Action: WSConnect, + Args: args, + } + tc.buf.Reset() + err = jsonI.NewEncoder(tc.buf).Encode(action) + if err != nil { + return err + } + err = tc.writeText(tc.buf.Bytes()) + if err != nil { + return err + } + var resp WSConnectResp + err = tc.readTo(&resp) + if err != nil { + return err + } + if resp.Code != 0 { + return taosErrors.NewError(resp.Code, resp.Message) + } + return nil +} + +func (tc *taosConn) writeText(data []byte) error { + tc.client.SetWriteDeadline(time.Now().Add(tc.writeTimeout)) + err := tc.client.WriteMessage(websocket.TextMessage, data) + if err != nil { + return NewBadConnErrorWithCtx(err, string(data)) + } + return nil +} + +func (tc *taosConn) readTo(to interface{}) error { + var outErr error + done := make(chan struct{}) + go func() { + defer func() { + close(done) + }() + mt, respBytes, err := tc.client.ReadMessage() + if err != nil { + outErr = NewBadConnError(err) + return + } + if mt != websocket.TextMessage { + outErr = NewBadConnErrorWithCtx(fmt.Errorf("readTo: got wrong message type %d", mt), formatBytes(respBytes)) + return + } + err = jsonI.Unmarshal(respBytes, to) + if err != nil { + outErr = NewBadConnErrorWithCtx(err, string(respBytes)) + return + } + }() + ctx, cancel := context.WithTimeout(context.Background(), tc.readTimeout) + defer cancel() + select { + case <-done: + return outErr + case <-ctx.Done(): + return NewBadConnError(ReadTimeoutError) + } +} + +func (tc *taosConn) readBytes() ([]byte, error) { + var respBytes []byte + var outErr error + done := make(chan struct{}) + go func() { + defer func() { + close(done) + }() + mt, message, err := tc.client.ReadMessage() + if err != nil { + outErr = NewBadConnError(err) + return + } + if mt != websocket.BinaryMessage { + outErr = NewBadConnErrorWithCtx(fmt.Errorf("readBytes: got wrong message type %d", mt), string(respBytes)) + return + } + respBytes = message + }() + ctx, cancel := context.WithTimeout(context.Background(), tc.readTimeout) + defer cancel() + select { + case <-done: + return respBytes, outErr + case <-ctx.Done(): + return nil, NewBadConnError(ReadTimeoutError) + } +} + +func formatBytes(bs []byte) string { + if len(bs) == 0 { + return "" + } + buffer := &strings.Builder{} + buffer.WriteByte('[') + for i := 0; i < len(bs); i++ { + fmt.Fprintf(buffer, "0x%02x", bs[i]) + if i != len(bs)-1 { + buffer.WriteByte(',') + } + } + buffer.WriteByte(']') + return buffer.String() +} diff --git a/taosWS/connection_test.go b/taosWS/connection_test.go new file mode 100644 index 0000000..200aefe --- /dev/null +++ b/taosWS/connection_test.go @@ -0,0 +1,45 @@ +package taosWS + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_formatBytes(t *testing.T) { + type args struct { + bs []byte + } + tests := []struct { + name string + args args + want string + }{ + { + name: "nothing", + args: args{ + bs: nil, + }, + want: "", + }, + { + name: "one byte", + args: args{ + bs: []byte{'a'}, + }, + want: "[0x61]", + }, + { + name: "two byes", + args: args{ + bs: []byte{'a', 'b'}, + }, + want: "[0x61,0x62]", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equalf(t, tt.want, formatBytes(tt.args.bs), "formatBytes(%v)", tt.args.bs) + }) + } +} diff --git a/taosWS/connector.go b/taosWS/connector.go new file mode 100644 index 0000000..8259d09 --- /dev/null +++ b/taosWS/connector.go @@ -0,0 +1,47 @@ +package taosWS + +import ( + "context" + "database/sql/driver" + + "github.com/taosdata/driver-go/v3/common" +) + +type connector struct { + cfg *config +} + +// Connect implements driver.Connector interface. +// Connect returns a connection to the database. +func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { + // Connect to Server + if len(c.cfg.user) == 0 { + c.cfg.user = common.DefaultUser + } + if len(c.cfg.passwd) == 0 { + c.cfg.passwd = common.DefaultPassword + } + if c.cfg.port == 0 { + c.cfg.port = common.DefaultHttpPort + } + if len(c.cfg.net) == 0 { + c.cfg.net = "ws" + } + if len(c.cfg.addr) == 0 { + c.cfg.addr = "127.0.0.1" + } + if c.cfg.readTimeout == 0 { + c.cfg.readTimeout = common.DefaultMessageTimeout + } + if c.cfg.writeTimeout == 0 { + c.cfg.writeTimeout = common.DefaultWriteWait + } + tc, err := newTaosConn(c.cfg) + return tc, err +} + +// Driver implements driver.Connector interface. +// Driver returns &TDengineDriver{}. +func (c *connector) Driver() driver.Driver { + return &TDengineDriver{} +} diff --git a/taosWS/connector_test.go b/taosWS/connector_test.go new file mode 100644 index 0000000..9ea2176 --- /dev/null +++ b/taosWS/connector_test.go @@ -0,0 +1,701 @@ +package taosWS + +import ( + "database/sql" + "fmt" + "math/rand" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/taosdata/driver-go/v3/types" +) + +func TestAllTypeQuery(t *testing.T) { + rand.Seed(time.Now().UnixNano()) + db, err := sql.Open("taosWS", dataSourceName) + if err != nil { + t.Fatal(err) + } + defer db.Close() + err = db.Ping() + if err != nil { + t.Fatal(err) + } + defer func() { + _, err = db.Exec("drop database if exists ws_test") + if err != nil { + t.Fatal(err) + } + }() + _, err = db.Exec("create database if not exists ws_test") + if err != nil { + t.Fatal(err) + } + var ( + v1 = true + v2 = int8(rand.Int()) + v3 = int16(rand.Int()) + v4 = rand.Int31() + v5 = int64(rand.Int31()) + v6 = uint8(rand.Uint32()) + v7 = uint16(rand.Uint32()) + v8 = rand.Uint32() + v9 = uint64(rand.Uint32()) + v10 = rand.Float32() + v11 = rand.Float64() + v12 = "test_binary" + v13 = "test_nchar" + ) + + _, err = db.Exec("create table if not exists ws_test.alltype(ts timestamp," + + "c1 bool," + + "c2 tinyint," + + "c3 smallint," + + "c4 int," + + "c5 bigint," + + "c6 tinyint unsigned," + + "c7 smallint unsigned," + + "c8 int unsigned," + + "c9 bigint unsigned," + + "c10 float," + + "c11 double," + + "c12 binary(20)," + + "c13 nchar(20)" + + ")" + + "tags(t json)", + ) + if err != nil { + t.Fatal(err) + } + now := time.Now().Round(time.Millisecond) + _, err = db.Exec(fmt.Sprintf(`insert into ws_test.t1 using ws_test.alltype tags('{"a":"b"}') values('%s',%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,'test_binary','test_nchar')`, now.Format(time.RFC3339Nano), v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11)) + if err != nil { + t.Fatal(err) + } + rows, err := db.Query(fmt.Sprintf("select * from ws_test.alltype where ts = '%s'", now.Format(time.RFC3339Nano))) + assert.NoError(t, err) + columns, err := rows.Columns() + assert.NoError(t, err) + t.Log(columns) + cTypes, err := rows.ColumnTypes() + assert.NoError(t, err) + t.Log(cTypes) + for rows.Next() { + var ( + ts time.Time + c1 bool + c2 int8 + c3 int16 + c4 int32 + c5 int64 + c6 uint8 + c7 uint16 + c8 uint32 + c9 uint64 + c10 float32 + c11 float64 + c12 string + c13 string + tt types.RawMessage + ) + err := rows.Scan( + &ts, + &c1, + &c2, + &c3, + &c4, + &c5, + &c6, + &c7, + &c8, + &c9, + &c10, + &c11, + &c12, + &c13, + &tt, + ) + assert.Equal(t, now.UTC(), ts.UTC()) + assert.Equal(t, v1, c1) + assert.Equal(t, v2, c2) + assert.Equal(t, v3, c3) + assert.Equal(t, v4, c4) + assert.Equal(t, v5, c5) + assert.Equal(t, v6, c6) + assert.Equal(t, v7, c7) + assert.Equal(t, v8, c8) + assert.Equal(t, v9, c9) + assert.Equal(t, v10, c10) + assert.Equal(t, v11, c11) + assert.Equal(t, v12, c12) + assert.Equal(t, v13, c13) + assert.Equal(t, types.RawMessage(`{"a":"b"}`), tt) + if err != nil { + t.Fatal(err) + } + if ts.IsZero() { + t.Fatal(ts) + } + + } +} + +func TestAllTypeQueryNull(t *testing.T) { + rand.Seed(time.Now().UnixNano()) + db, err := sql.Open("taosWS", dataSourceName) + if err != nil { + t.Fatal(err) + } + defer db.Close() + err = db.Ping() + if err != nil { + t.Fatal(err) + } + defer func() { + _, err = db.Exec("drop database if exists ws_test_null") + if err != nil { + t.Fatal(err) + } + }() + _, err = db.Exec("create database if not exists ws_test_null") + if err != nil { + t.Fatal(err) + } + + _, err = db.Exec("create table if not exists ws_test_null.alltype(ts timestamp," + + "c1 bool," + + "c2 tinyint," + + "c3 smallint," + + "c4 int," + + "c5 bigint," + + "c6 tinyint unsigned," + + "c7 smallint unsigned," + + "c8 int unsigned," + + "c9 bigint unsigned," + + "c10 float," + + "c11 double," + + "c12 binary(20)," + + "c13 nchar(20)" + + ")" + + "tags(t json)", + ) + if err != nil { + t.Fatal(err) + } + now := time.Now().Round(time.Millisecond) + _, err = db.Exec(fmt.Sprintf(`insert into ws_test_null.t1 using ws_test_null.alltype tags('null') values('%s',null,null,null,null,null,null,null,null,null,null,null,null,null)`, now.Format(time.RFC3339Nano))) + if err != nil { + t.Fatal(err) + } + rows, err := db.Query(fmt.Sprintf("select * from ws_test_null.alltype where ts = '%s'", now.Format(time.RFC3339Nano))) + assert.NoError(t, err) + columns, err := rows.Columns() + assert.NoError(t, err) + t.Log(columns) + cTypes, err := rows.ColumnTypes() + assert.NoError(t, err) + t.Log(cTypes) + for rows.Next() { + var ( + ts time.Time + c1 *bool + c2 *int8 + c3 *int16 + c4 *int32 + c5 *int64 + c6 *uint8 + c7 *uint16 + c8 *uint32 + c9 *uint64 + c10 *float32 + c11 *float64 + c12 *string + c13 *string + tt *string + ) + err := rows.Scan( + &ts, + &c1, + &c2, + &c3, + &c4, + &c5, + &c6, + &c7, + &c8, + &c9, + &c10, + &c11, + &c12, + &c13, + &tt, + ) + assert.Equal(t, now.UTC(), ts.UTC()) + assert.Nil(t, c1) + assert.Nil(t, c2) + assert.Nil(t, c3) + assert.Nil(t, c4) + assert.Nil(t, c5) + assert.Nil(t, c6) + assert.Nil(t, c7) + assert.Nil(t, c8) + assert.Nil(t, c9) + assert.Nil(t, c10) + assert.Nil(t, c11) + assert.Nil(t, c12) + assert.Nil(t, c13) + assert.Nil(t, tt) + if err != nil { + + t.Fatal(err) + } + if ts.IsZero() { + t.Fatal(ts) + } + + } +} + +func TestAllTypeQueryCompression(t *testing.T) { + rand.Seed(time.Now().UnixNano()) + db, err := sql.Open("taosWS", dataSourceNameWithCompression) + if err != nil { + t.Fatal(err) + } + defer db.Close() + err = db.Ping() + if err != nil { + t.Fatal(err) + } + defer func() { + _, err = db.Exec("drop database if exists ws_test") + if err != nil { + t.Fatal(err) + } + }() + _, err = db.Exec("create database if not exists ws_test") + if err != nil { + t.Fatal(err) + } + var ( + v1 = true + v2 = int8(rand.Int()) + v3 = int16(rand.Int()) + v4 = rand.Int31() + v5 = int64(rand.Int31()) + v6 = uint8(rand.Uint32()) + v7 = uint16(rand.Uint32()) + v8 = rand.Uint32() + v9 = uint64(rand.Uint32()) + v10 = rand.Float32() + v11 = rand.Float64() + v12 = "test_binary" + v13 = "test_nchar" + ) + + _, err = db.Exec("create table if not exists ws_test.alltype(ts timestamp," + + "c1 bool," + + "c2 tinyint," + + "c3 smallint," + + "c4 int," + + "c5 bigint," + + "c6 tinyint unsigned," + + "c7 smallint unsigned," + + "c8 int unsigned," + + "c9 bigint unsigned," + + "c10 float," + + "c11 double," + + "c12 binary(20)," + + "c13 nchar(20)" + + ")" + + "tags(t json)", + ) + if err != nil { + t.Fatal(err) + } + now := time.Now().Round(time.Millisecond) + _, err = db.Exec(fmt.Sprintf(`insert into ws_test.t1 using ws_test.alltype tags('{"a":"b"}') values('%s',%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,'test_binary','test_nchar')`, now.Format(time.RFC3339Nano), v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11)) + if err != nil { + t.Fatal(err) + } + rows, err := db.Query(fmt.Sprintf("select * from ws_test.alltype where ts = '%s'", now.Format(time.RFC3339Nano))) + assert.NoError(t, err) + columns, err := rows.Columns() + assert.NoError(t, err) + t.Log(columns) + cTypes, err := rows.ColumnTypes() + assert.NoError(t, err) + t.Log(cTypes) + for rows.Next() { + var ( + ts time.Time + c1 bool + c2 int8 + c3 int16 + c4 int32 + c5 int64 + c6 uint8 + c7 uint16 + c8 uint32 + c9 uint64 + c10 float32 + c11 float64 + c12 string + c13 string + tt types.RawMessage + ) + err := rows.Scan( + &ts, + &c1, + &c2, + &c3, + &c4, + &c5, + &c6, + &c7, + &c8, + &c9, + &c10, + &c11, + &c12, + &c13, + &tt, + ) + assert.Equal(t, now.UTC(), ts.UTC()) + assert.Equal(t, v1, c1) + assert.Equal(t, v2, c2) + assert.Equal(t, v3, c3) + assert.Equal(t, v4, c4) + assert.Equal(t, v5, c5) + assert.Equal(t, v6, c6) + assert.Equal(t, v7, c7) + assert.Equal(t, v8, c8) + assert.Equal(t, v9, c9) + assert.Equal(t, v10, c10) + assert.Equal(t, v11, c11) + assert.Equal(t, v12, c12) + assert.Equal(t, v13, c13) + assert.Equal(t, types.RawMessage(`{"a":"b"}`), tt) + if err != nil { + t.Fatal(err) + } + if ts.IsZero() { + t.Fatal(ts) + } + } +} + +func TestAllTypeQueryWithoutJson(t *testing.T) { + rand.Seed(time.Now().UnixNano()) + db, err := sql.Open("taosWS", dataSourceName) + if err != nil { + t.Fatal(err) + } + defer db.Close() + err = db.Ping() + if err != nil { + t.Fatal(err) + } + defer func() { + _, err = db.Exec("drop database if exists ws_test_without_json") + if err != nil { + t.Fatal(err) + } + }() + _, err = db.Exec("create database if not exists ws_test_without_json") + if err != nil { + t.Fatal(err) + } + var ( + v1 = false + v2 = int8(rand.Int()) + v3 = int16(rand.Int()) + v4 = rand.Int31() + v5 = int64(rand.Int31()) + v6 = uint8(rand.Uint32()) + v7 = uint16(rand.Uint32()) + v8 = rand.Uint32() + v9 = uint64(rand.Uint32()) + v10 = rand.Float32() + v11 = rand.Float64() + v12 = "test_binary" + v13 = "test_nchar" + ) + + _, err = db.Exec("create table if not exists ws_test_without_json.all_type(ts timestamp," + + "c1 bool," + + "c2 tinyint," + + "c3 smallint," + + "c4 int," + + "c5 bigint," + + "c6 tinyint unsigned," + + "c7 smallint unsigned," + + "c8 int unsigned," + + "c9 bigint unsigned," + + "c10 float," + + "c11 double," + + "c12 binary(20)," + + "c13 nchar(20)" + + ")", + ) + if err != nil { + t.Fatal(err) + } + now := time.Now().Round(time.Millisecond) + _, err = db.Exec(fmt.Sprintf(`insert into ws_test_without_json.all_type values('%s',%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,'test_binary','test_nchar')`, now.Format(time.RFC3339Nano), v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11)) + if err != nil { + t.Fatal(err) + } + rows, err := db.Query(fmt.Sprintf("select * from ws_test_without_json.all_type where ts = '%s'", now.Format(time.RFC3339Nano))) + assert.NoError(t, err) + columns, err := rows.Columns() + assert.NoError(t, err) + t.Log(columns) + cTypes, err := rows.ColumnTypes() + assert.NoError(t, err) + t.Log(cTypes) + for rows.Next() { + var ( + ts time.Time + c1 bool + c2 int8 + c3 int16 + c4 int32 + c5 int64 + c6 uint8 + c7 uint16 + c8 uint32 + c9 uint64 + c10 float32 + c11 float64 + c12 string + c13 string + ) + err := rows.Scan( + &ts, + &c1, + &c2, + &c3, + &c4, + &c5, + &c6, + &c7, + &c8, + &c9, + &c10, + &c11, + &c12, + &c13, + ) + assert.Equal(t, now.UTC(), ts.UTC()) + assert.Equal(t, v1, c1) + assert.Equal(t, v2, c2) + assert.Equal(t, v3, c3) + assert.Equal(t, v4, c4) + assert.Equal(t, v5, c5) + assert.Equal(t, v6, c6) + assert.Equal(t, v7, c7) + assert.Equal(t, v8, c8) + assert.Equal(t, v9, c9) + assert.Equal(t, v10, c10) + assert.Equal(t, v11, c11) + assert.Equal(t, v12, c12) + assert.Equal(t, v13, c13) + if err != nil { + t.Fatal(err) + } + if ts.IsZero() { + t.Fatal(ts) + } + + } +} + +func TestAllTypeQueryNullWithoutJson(t *testing.T) { + rand.Seed(time.Now().UnixNano()) + db, err := sql.Open("taosWS", dataSourceName) + if err != nil { + t.Fatal(err) + } + defer db.Close() + err = db.Ping() + if err != nil { + t.Fatal(err) + } + defer func() { + _, err = db.Exec("drop database if exists ws_test_without_json_null") + if err != nil { + t.Fatal(err) + } + }() + _, err = db.Exec("create database if not exists ws_test_without_json_null") + if err != nil { + t.Fatal(err) + } + + _, err = db.Exec("create table if not exists ws_test_without_json_null.all_type(ts timestamp," + + "c1 bool," + + "c2 tinyint," + + "c3 smallint," + + "c4 int," + + "c5 bigint," + + "c6 tinyint unsigned," + + "c7 smallint unsigned," + + "c8 int unsigned," + + "c9 bigint unsigned," + + "c10 float," + + "c11 double," + + "c12 binary(20)," + + "c13 nchar(20)" + + ")", + ) + if err != nil { + t.Fatal(err) + } + now := time.Now().Round(time.Millisecond) + _, err = db.Exec(fmt.Sprintf(`insert into ws_test_without_json_null.all_type values('%s',null,null,null,null,null,null,null,null,null,null,null,null,null)`, now.Format(time.RFC3339Nano))) + if err != nil { + t.Fatal(err) + } + rows, err := db.Query(fmt.Sprintf("select * from ws_test_without_json_null.all_type where ts = '%s'", now.Format(time.RFC3339Nano))) + assert.NoError(t, err) + columns, err := rows.Columns() + assert.NoError(t, err) + t.Log(columns) + cTypes, err := rows.ColumnTypes() + assert.NoError(t, err) + t.Log(cTypes) + for rows.Next() { + var ( + ts time.Time + c1 *bool + c2 *int8 + c3 *int16 + c4 *int32 + c5 *int64 + c6 *uint8 + c7 *uint16 + c8 *uint32 + c9 *uint64 + c10 *float32 + c11 *float64 + c12 *string + c13 *string + ) + err := rows.Scan( + &ts, + &c1, + &c2, + &c3, + &c4, + &c5, + &c6, + &c7, + &c8, + &c9, + &c10, + &c11, + &c12, + &c13, + ) + assert.Equal(t, now.UTC(), ts.UTC()) + assert.Nil(t, c1) + assert.Nil(t, c2) + assert.Nil(t, c3) + assert.Nil(t, c4) + assert.Nil(t, c5) + assert.Nil(t, c6) + assert.Nil(t, c7) + assert.Nil(t, c8) + assert.Nil(t, c9) + assert.Nil(t, c10) + assert.Nil(t, c11) + assert.Nil(t, c12) + assert.Nil(t, c13) + if err != nil { + + t.Fatal(err) + } + if ts.IsZero() { + t.Fatal(ts) + } + + } +} + +func TestBatch(t *testing.T) { + now := time.Now() + tests := []struct { + name string + sql string + isQuery bool + }{ + { + name: "drop db", + sql: "drop database if exists test_batch", + }, + { + name: "create db", + sql: "create database test_batch", + }, + { + name: "use db", + sql: "use test_batch", + }, + { + name: "create table", + sql: "create table test(ts timestamp,v int)", + }, + { + name: "insert 1", + sql: fmt.Sprintf("insert into test values ('%s',1)", now.Format(time.RFC3339Nano)), + }, + { + name: "insert 2", + sql: fmt.Sprintf("insert into test values ('%s',2)", now.Add(time.Second).Format(time.RFC3339Nano)), + }, + { + name: "query all", + sql: "select * from test order by ts", + isQuery: true, + }, + { + name: "drop database", + sql: "drop database if exists test_batch", + }, + } + db, err := sql.Open("taosWS", dataSourceName) + if err != nil { + t.Fatal(err) + } + defer db.Close() + //err = db.Ping() + //if err != nil { + // t.Fatal(err) + //} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.isQuery { + result, err := db.Query(tt.sql) + assert.NoError(t, err) + var check [][]interface{} + for result.Next() { + var ts time.Time + var v int + err := result.Scan(&ts, &v) + assert.NoError(t, err) + check = append(check, []interface{}{ts, v}) + } + assert.Equal(t, 2, len(check)) + assert.Equal(t, now.UnixNano()/1e6, check[0][0].(time.Time).UnixNano()/1e6) + assert.Equal(t, now.Add(time.Second).UnixNano()/1e6, check[1][0].(time.Time).UnixNano()/1e6) + assert.Equal(t, int(1), check[0][1].(int)) + assert.Equal(t, int(2), check[1][1].(int)) + } else { + _, err := db.Exec(tt.sql) + assert.NoError(t, err) + } + }) + } +} diff --git a/taosWS/driver.go b/taosWS/driver.go new file mode 100644 index 0000000..ebe457c --- /dev/null +++ b/taosWS/driver.go @@ -0,0 +1,28 @@ +package taosWS + +import ( + "context" + "database/sql" + "database/sql/driver" +) + +// TDengineDriver is exported to make the driver directly accessible. +// In general the driver is used via the database/sql package. +type TDengineDriver struct{} + +// Open new Connection. +// the DSN string is formatted +func (d TDengineDriver) Open(dsn string) (driver.Conn, error) { + cfg, err := parseDSN(dsn) + if err != nil { + return nil, err + } + c := &connector{ + cfg: cfg, + } + return c.Connect(context.Background()) +} + +func init() { + sql.Register("taosWS", &TDengineDriver{}) +} diff --git a/taosWS/driver_test.go b/taosWS/driver_test.go new file mode 100644 index 0000000..da3543d --- /dev/null +++ b/taosWS/driver_test.go @@ -0,0 +1,269 @@ +package taosWS + +import ( + "database/sql" + "database/sql/driver" + "errors" + "fmt" + "log" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +// Ensure that all the driver interfaces are implemented +func TestMain(m *testing.M) { + m.Run() + db, err := sql.Open(driverName, dataSourceName) + if err != nil { + log.Fatalf("error on: sql.open %s", err.Error()) + } + defer db.Close() + defer func() { + db.Exec(fmt.Sprintf("drop database if exists %s", dbName)) + }() +} + +var ( + driverName = "taosWS" + user = "root" + password = "taosdata" + host = "127.0.0.1" + port = 6041 + dbName = "test_taos_ws" + dataSourceName = fmt.Sprintf("%s:%s@ws(%s:%d)/", user, password, host, port) + dataSourceNameWithCompression = fmt.Sprintf("%s:%s@ws(%s:%d)/?disableCompression=false", user, password, host, port) +) + +type DBTest struct { + *testing.T + *sql.DB +} + +func NewDBTest(t *testing.T) (dbt *DBTest) { + db, err := sql.Open(driverName, dataSourceName) + if err != nil { + t.Fatalf("error on: sql.open %s", err.Error()) + return + } + dbt = &DBTest{t, db} + return +} + +func (dbt *DBTest) CreateTables(numOfSubTab int) { + _, err := dbt.mustExec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbName)) + if err != nil { + dbt.Fatalf("create tables error %s", err) + } + _, err = dbt.mustExec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s", dbName)) + if err != nil { + dbt.Fatalf("create tables error %s", err) + } + _, err = dbt.mustExec(fmt.Sprintf("drop table if exists %s.super", dbName)) + if err != nil { + dbt.Fatalf("create tables error %s", err) + } + _, err = dbt.mustExec(fmt.Sprintf("CREATE TABLE %s.super (ts timestamp, v BOOL) tags (degress int)", dbName)) + if err != nil { + dbt.Fatalf("create tables error %s", err) + } + for i := 0; i < numOfSubTab; i++ { + _, err := dbt.mustExec(fmt.Sprintf("create table %s.t%d using %s.super tags(%d)", dbName, i%10, dbName, i)) + if err != nil { + dbt.Fatalf("create tables error %s", err) + } + } +} +func (dbt *DBTest) InsertInto(numOfSubTab, numOfItems int) { + now := time.Now() + t := now.Add(-100 * time.Minute) + for i := 0; i < numOfItems; i++ { + dbt.mustExec(fmt.Sprintf("insert into %s.t%d values(%d, %t)", dbName, i%numOfSubTab, t.UnixNano()/int64(time.Millisecond)+int64(i), i%2 == 0)) + } +} + +type TestResult struct { + ts string + value bool +} + +func runTests(t *testing.T, tests ...func(dbt *DBTest)) { + dbt := NewDBTest(t) + // prepare data + dbt.Exec("DROP TABLE IF EXISTS test_taos_ws.test") + var numOfSubTables = 10 + var numOfItems = 200 + dbt.CreateTables(numOfSubTables) + dbt.InsertInto(numOfSubTables, numOfItems) + for _, test := range tests { + test(dbt) + dbt.Exec("DROP TABLE IF EXISTS test_taos_ws.test") + } +} + +func (dbt *DBTest) mustExec(query string, args ...interface{}) (res sql.Result, err error) { + res, err = dbt.Exec(query, args...) + return +} + +func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *sql.Rows, err error) { + rows, err = dbt.Query(query, args...) + return +} + +func TestEmptyQuery(t *testing.T) { + runTests(t, func(dbt *DBTest) { + // just a comment, no query + _, err := dbt.mustExec("") + if err == nil { + dbt.Fatalf("error is expected") + } + + }) +} + +func TestErrorQuery(t *testing.T) { + runTests(t, func(dbt *DBTest) { + // just a comment, no query + _, err := dbt.mustExec("xxxxxxx inot") + if err == nil { + dbt.Fatalf("error is expected") + } + }) +} + +type ( + execFunc func(dbt *DBTest, query string, exec bool, err error, expected int64) int64 +) + +type Obj struct { + query string + err error + exec bool + fp execFunc + expect int64 +} + +var ( + errUser = errors.New("user error") + fp = func(dbt *DBTest, query string, exec bool, eErr error, expected int64) int64 { + var ret int64 = 0 + if exec == false { + rows, err := dbt.mustQuery(query) + if eErr == errUser && err != nil { + return ret + } + if err != nil { + dbt.Errorf("%s is not expected, err: %s", query, err.Error()) + return ret + } else { + var count int64 = 0 + for rows.Next() { + var row TestResult + if err := rows.Scan(&(row.ts), &(row.value)); err != nil { + dbt.Error(err.Error()) + return ret + } + count = count + 1 + } + rows.Close() + ret = count + if expected != -1 && count != expected { + dbt.Errorf("%s is not expected, err: %s", query, errors.New("result is not expected")) + } + } + } else { + res, err := dbt.mustExec(query) + if err != eErr { + dbt.Fatalf("%s is not expected, err: %s", query, err.Error()) + } else { + count, err := res.RowsAffected() + if err != nil { + dbt.Fatalf("%s is not expected, err: %s", query, err.Error()) + } + if expected != -1 && count != expected { + dbt.Fatalf("%s is not expected , err: %s", query, errors.New("result is not expected")) + } + } + } + return ret + } +) + +func TestAny(t *testing.T) { + runTests(t, func(dbt *DBTest) { + now := time.Now() + tests := make([]*Obj, 0, 100) + tests = append(tests, + &Obj{fmt.Sprintf("insert into %s.t%d values(%d, %t)", dbName, 0, now.UnixNano()/int64(time.Millisecond)-1, false), nil, true, fp, int64(1)}) + tests = append(tests, + &Obj{fmt.Sprintf("insert into %s.t%d values(%d, %t)", dbName, 0, now.UnixNano()/int64(time.Millisecond)-1, false), nil, true, fp, int64(1)}) + tests = append(tests, + &Obj{fmt.Sprintf("select first(*) from %s.t%d", dbName, 0), nil, false, fp, int64(1)}) + tests = append(tests, + &Obj{"select error", errUser, false, fp, int64(1)}) + tests = append(tests, + &Obj{fmt.Sprintf("select * from %s.t%d", dbName, 0), nil, false, fp, int64(-1)}) + tests = append(tests, + &Obj{fmt.Sprintf("select * from %s.t%d", dbName, 0), nil, false, fp, int64(-1)}) + + for _, obj := range tests { + fp = obj.fp + fp(dbt, obj.query, obj.exec, obj.err, obj.expect) + } + }) +} + +func TestChinese(t *testing.T) { + db, err := sql.Open(driverName, dataSourceName) + if err != nil { + t.Fatalf("error on: sql.open %s", err.Error()) + return + } + defer db.Close() + defer func() { + _, err = db.Exec("drop database if exists test_chinese") + if err != nil { + t.Error(err) + return + } + }() + _, err = db.Exec("create database if not exists test_chinese") + if err != nil { + t.Error(err) + return + } + _, err = db.Exec("drop table if exists test_chinese.chinese") + if err != nil { + t.Error(err) + return + } + _, err = db.Exec("create table if not exists test_chinese.chinese(ts timestamp,v nchar(32))") + if err != nil { + t.Error(err) + return + } + _, err = db.Exec(`INSERT INTO test_chinese.chinese (ts, v) VALUES (?, ?)`, "1641010332000", "'阴天'") + if err != nil { + t.Error(err) + return + } + rows, err := db.Query("select * from test_chinese.chinese") + if err != nil { + t.Error(err) + return + } + counter := 0 + for rows.Next() { + counter += 1 + row := make([]driver.Value, 2) + err := rows.Scan(&row[0], &row[1]) + if err != nil { + t.Error(err) + return + } + t.Log(row) + } + assert.Equal(t, 1, counter) +} diff --git a/taosWS/dsn.go b/taosWS/dsn.go new file mode 100644 index 0000000..ca5ff22 --- /dev/null +++ b/taosWS/dsn.go @@ -0,0 +1,169 @@ +package taosWS + +import ( + "net/url" + "strconv" + "strings" + "time" + + "github.com/taosdata/driver-go/v3/errors" +) + +var ( + errInvalidDSNUnescaped = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: did you forget to escape a param value?"} + errInvalidDSNAddr = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: network address not terminated (missing closing brace)"} + errInvalidDSNPort = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: network port is not a valid number"} + errInvalidDSNNoSlash = &errors.TaosError{Code: 0xffff, ErrStr: "invalid DSN: missing the slash separating the database name"} +) + +// Config is a configuration parsed from a DSN string. +// If a new Config is created instead of being parsed from a DSN string, +// the NewConfig function should be used, which sets default values. +type config struct { + user string // Username + passwd string // Password (requires User) + net string // Network type + addr string // Network address (requires Net) + port int + dbName string // Database name + params map[string]string // Connection parameters + interpolateParams bool // Interpolate placeholders into query string + token string // cloud platform token + readTimeout time.Duration // read message timeout + writeTimeout time.Duration // write message timeout +} + +// NewConfig creates a new Config and sets default values. +func newConfig() *config { + return &config{ + interpolateParams: true, + } +} + +// ParseDSN parses the DSN string to a Config +func parseDSN(dsn string) (cfg *config, err error) { + // New config with some default values + cfg = newConfig() + + // [user[:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN] + // Find the last '/' (since the password or the net addr might contain a '/') + foundSlash := false + for i := len(dsn) - 1; i >= 0; i-- { + if dsn[i] == '/' { + foundSlash = true + var j, k int + + // left part is empty if i <= 0 + if i > 0 { + // [username[:password]@][protocol[(address)]] + // Find the last '@' in dsn[:i] + for j = i; j >= 0; j-- { + if dsn[j] == '@' { + // username[:password] + // Find the first ':' in dsn[:j] + for k = 0; k < j; k++ { + if dsn[k] == ':' { + cfg.passwd = dsn[k+1 : j] + break + } + } + cfg.user = dsn[:k] + + break + } + } + + // [protocol[(address)]] + // Find the first '(' in dsn[j+1:i] + for k = j + 1; k < i; k++ { + if dsn[k] == '(' { + // dsn[i-1] must be == ')' if an address is specified + if dsn[i-1] != ')' { + if strings.ContainsRune(dsn[k+1:i], ')') { + return nil, errInvalidDSNUnescaped + } + //return nil, errInvalidDSNAddr + } + strList := strings.Split(dsn[k+1:i-1], ":") + if len(strList) == 1 { + return nil, errInvalidDSNAddr + } + if len(strList[0]) != 0 { + cfg.addr = strList[0] + cfg.port, err = strconv.Atoi(strList[1]) + if err != nil { + return nil, errInvalidDSNPort + } + } + break + } + } + cfg.net = dsn[j+1 : k] + } + + // dbname[?param1=value1&...¶mN=valueN] + // Find the first '?' in dsn[i+1:] + for j = i + 1; j < len(dsn); j++ { + if dsn[j] == '?' { + if err = parseDSNParams(cfg, dsn[j+1:]); err != nil { + return + } + break + } + } + cfg.dbName = dsn[i+1 : j] + + break + } + } + + if !foundSlash && len(dsn) > 0 { + return nil, errInvalidDSNNoSlash + } + + return +} + +// parseDSNParams parses the DSN "query string" +// Values must be url.QueryEscape'ed +func parseDSNParams(cfg *config, params string) (err error) { + for _, v := range strings.Split(params, "&") { + param := strings.SplitN(v, "=", 2) + if len(param) != 2 { + continue + } + + // cfg params + switch value := param[1]; param[0] { + // Enable client side placeholder substitution + case "interpolateParams": + cfg.interpolateParams, err = strconv.ParseBool(value) + if err != nil { + return &errors.TaosError{Code: 0xffff, ErrStr: "invalid bool value: " + value} + } + case "token": + cfg.token = value + case "readTimeout": + cfg.readTimeout, err = time.ParseDuration(value) + if err != nil { + return &errors.TaosError{Code: 0xffff, ErrStr: "invalid duration value: " + value} + } + case "writeTimeout": + cfg.writeTimeout, err = time.ParseDuration(value) + if err != nil { + return &errors.TaosError{Code: 0xffff, ErrStr: "invalid duration value: " + value} + } + default: + // lazy init + if cfg.params == nil { + cfg.params = make(map[string]string) + } + + if cfg.params[param[0]], err = url.QueryUnescape(value); err != nil { + return + } + } + } + + return +} diff --git a/taosWS/dsn_test.go b/taosWS/dsn_test.go new file mode 100644 index 0000000..356be70 --- /dev/null +++ b/taosWS/dsn_test.go @@ -0,0 +1,38 @@ +package taosWS + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestParseDsn(t *testing.T) { + tests := []struct { + dsn string + errs string + want *config + }{ + {dsn: "abcd", errs: "invalid DSN: missing the slash separating the database name"}, + {dsn: "user:passwd@ws(fqdn:6041)/dbname", want: &config{user: "user", passwd: "passwd", net: "ws", addr: "fqdn", port: 6041, dbName: "dbname", interpolateParams: true}}, + {dsn: "user:passwd@ws()/dbname", errs: "invalid DSN: network address not terminated (missing closing brace)"}, + {dsn: "user:passwd@ws(:)/dbname", want: &config{user: "user", passwd: "passwd", net: "ws", dbName: "dbname", interpolateParams: true}}, + {dsn: "user:passwd@ws(:0)/dbname", want: &config{user: "user", passwd: "passwd", net: "ws", dbName: "dbname", interpolateParams: true}}, + {dsn: "user:passwd@wss(:0)/", want: &config{user: "user", passwd: "passwd", net: "wss", interpolateParams: true}}, + {dsn: "user:passwd@wss(:0)/?interpolateParams=false&test=1", want: &config{user: "user", passwd: "passwd", net: "wss", params: map[string]string{"test": "1"}}}, + {dsn: "user:passwd@wss(:0)/?interpolateParams=false&token=token", want: &config{user: "user", passwd: "passwd", net: "wss", token: "token"}}, + {dsn: "user:passwd@wss(:0)/?writeTimeout=8s&readTimeout=10m", want: &config{user: "user", passwd: "passwd", net: "wss", readTimeout: 10 * time.Minute, writeTimeout: 8 * time.Second, interpolateParams: true}}, + } + for _, tc := range tests { + t.Run(tc.dsn, func(t *testing.T) { + cfg, err := parseDSN(tc.dsn) + if err != nil { + if errs := err.Error(); errs != tc.errs { + t.Fatal(tc.errs, "\n", errs) + } + return + } + assert.Equal(t, tc.want, cfg) + }) + } +} diff --git a/taosWS/error.go b/taosWS/error.go new file mode 100644 index 0000000..a156d8e --- /dev/null +++ b/taosWS/error.go @@ -0,0 +1,30 @@ +package taosWS + +import ( + "database/sql/driver" + "fmt" +) + +type BadConnError struct { + err error + ctx string +} + +func NewBadConnError(err error) *BadConnError { + return &BadConnError{err: err} +} + +func NewBadConnErrorWithCtx(err error, ctx string) *BadConnError { + return &BadConnError{err: err, ctx: ctx} +} + +func (*BadConnError) Unwrap() error { + return driver.ErrBadConn +} + +func (e *BadConnError) Error() string { + if len(e.ctx) == 0 { + return e.err.Error() + } + return fmt.Sprintf("error %s: context: %s", e.err.Error(), e.ctx) +} diff --git a/taosWS/error_test.go b/taosWS/error_test.go new file mode 100644 index 0000000..c364f6d --- /dev/null +++ b/taosWS/error_test.go @@ -0,0 +1,19 @@ +package taosWS + +import ( + "database/sql/driver" + "errors" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBadConnError(t *testing.T) { + nothingErr := errors.New("error") + err := NewBadConnError(nothingErr) + assert.ErrorIs(t, err, driver.ErrBadConn) + assert.Equal(t, "error", err.Error()) + err = NewBadConnErrorWithCtx(nothingErr, "nothing") + assert.ErrorIs(t, err, driver.ErrBadConn) + assert.Equal(t, "error error: context: nothing", err.Error()) +} diff --git a/taosWS/proto.go b/taosWS/proto.go new file mode 100644 index 0000000..fd2fb39 --- /dev/null +++ b/taosWS/proto.go @@ -0,0 +1,71 @@ +package taosWS + +import "encoding/json" + +type WSConnectReq struct { + ReqID uint64 `json:"req_id"` + User string `json:"user"` + Password string `json:"password"` + DB string `json:"db"` +} + +type WSConnectResp struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` +} + +type WSQueryReq struct { + ReqID uint64 `json:"req_id"` + SQL string `json:"sql"` +} + +type WSQueryResp struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + ID uint64 `json:"id"` + IsUpdate bool `json:"is_update"` + AffectedRows int `json:"affected_rows"` + FieldsCount int `json:"fields_count"` + FieldsNames []string `json:"fields_names"` + FieldsTypes []uint8 `json:"fields_types"` + FieldsLengths []int64 `json:"fields_lengths"` + Precision int `json:"precision"` +} + +type WSFetchReq struct { + ReqID uint64 `json:"req_id"` + ID uint64 `json:"id"` +} + +type WSFetchResp struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + ID uint64 `json:"id"` + Completed bool `json:"completed"` + Lengths []int `json:"lengths"` + Rows int `json:"rows"` +} + +type WSFetchBlockReq struct { + ReqID uint64 `json:"req_id"` + ID uint64 `json:"id"` +} + +type WSFreeResultReq struct { + ReqID uint64 `json:"req_id"` + ID uint64 `json:"id"` +} + +type WSAction struct { + Action string `json:"action"` + Args json.RawMessage `json:"args"` +} diff --git a/taosWS/rows.go b/taosWS/rows.go new file mode 100644 index 0000000..b75f4e2 --- /dev/null +++ b/taosWS/rows.go @@ -0,0 +1,180 @@ +package taosWS + +import ( + "bytes" + "database/sql/driver" + "encoding/json" + "io" + "reflect" + "unsafe" + + "github.com/taosdata/driver-go/v3/common" + "github.com/taosdata/driver-go/v3/common/parser" + taosErrors "github.com/taosdata/driver-go/v3/errors" +) + +type rows struct { + buf *bytes.Buffer + blockPtr unsafe.Pointer + blockOffset int + blockSize int + resultID uint64 + block []byte + conn *taosConn + fieldsCount int + fieldsNames []string + fieldsTypes []uint8 + fieldsLengths []int64 + precision int +} + +func (rs *rows) Columns() []string { + return rs.fieldsNames +} + +func (rs *rows) ColumnTypeDatabaseTypeName(i int) string { + return common.TypeNameMap[int(rs.fieldsTypes[i])] +} + +func (rs *rows) ColumnTypeLength(i int) (length int64, ok bool) { + return rs.fieldsLengths[i], ok +} + +func (rs *rows) ColumnTypeScanType(i int) reflect.Type { + t, exist := common.ColumnTypeMap[int(rs.fieldsTypes[i])] + if !exist { + return common.UnknownType + } + return t +} + +func (rs *rows) Close() error { + rs.blockPtr = nil + rs.block = nil + return rs.freeResult() +} + +func (rs *rows) Next(dest []driver.Value) error { + if rs.blockPtr == nil { + err := rs.taosFetchBlock() + if err != nil { + return err + } + } + if rs.blockSize == 0 { + rs.blockPtr = nil + rs.block = nil + return io.EOF + } + if rs.blockOffset >= rs.blockSize { + err := rs.taosFetchBlock() + if err != nil { + return err + } + } + if rs.blockSize == 0 { + rs.blockPtr = nil + rs.block = nil + return io.EOF + } + parser.ReadRow(dest, rs.blockPtr, rs.blockSize, rs.blockOffset, rs.fieldsTypes, rs.precision) + rs.blockOffset += 1 + return nil +} + +func (rs *rows) taosFetchBlock() error { + reqID := rs.conn.generateReqID() + req := &WSFetchReq{ + ReqID: reqID, + ID: rs.resultID, + } + args, err := json.Marshal(req) + if err != nil { + return err + } + action := &WSAction{ + Action: WSFetch, + Args: args, + } + rs.buf.Reset() + + err = jsonI.NewEncoder(rs.buf).Encode(action) + if err != nil { + return err + } + err = rs.conn.writeText(rs.buf.Bytes()) + if err != nil { + return err + } + var resp WSFetchResp + err = rs.conn.readTo(&resp) + if err != nil { + return err + } + if resp.Code != 0 { + return taosErrors.NewError(resp.Code, resp.Message) + } + if resp.Completed { + rs.blockSize = 0 + return nil + } else { + rs.blockSize = resp.Rows + return rs.fetchBlock() + } +} + +func (rs *rows) fetchBlock() error { + reqID := rs.conn.generateReqID() + req := &WSFetchBlockReq{ + ReqID: reqID, + ID: rs.resultID, + } + args, err := json.Marshal(req) + if err != nil { + return err + } + action := &WSAction{ + Action: WSFetchBlock, + Args: args, + } + rs.buf.Reset() + err = jsonI.NewEncoder(rs.buf).Encode(action) + if err != nil { + return err + } + err = rs.conn.writeText(rs.buf.Bytes()) + if err != nil { + return err + } + respBytes, err := rs.conn.readBytes() + if err != nil { + return err + } + rs.block = respBytes + rs.blockPtr = unsafe.Pointer(*(*uintptr)(unsafe.Pointer(&rs.block)) + uintptr(16)) + rs.blockOffset = 0 + return nil +} + +func (rs *rows) freeResult() error { + tc := rs.conn + reqID := tc.generateReqID() + req := &WSFreeResultReq{ + ReqID: reqID, + ID: rs.resultID, + } + args, err := json.Marshal(req) + if err != nil { + return err + } + action := &WSAction{ + Action: WSFreeResult, + Args: args, + } + rs.buf.Reset() + err = jsonI.NewEncoder(rs.buf).Encode(action) + if err != nil { + return err + } + return nil +} diff --git a/wrapper/asynccb.go b/wrapper/asynccb.go index 659b318..14446c5 100644 --- a/wrapper/asynccb.go +++ b/wrapper/asynccb.go @@ -21,18 +21,18 @@ type Caller interface { //export QueryCallback func QueryCallback(p unsafe.Pointer, res *C.TAOS_RES, code C.int) { - caller := cgo.Handle(p).Value().(Caller) + caller := (*(*cgo.Handle)(p)).Value().(Caller) caller.QueryCall(unsafe.Pointer(res), int(code)) } //export FetchRowsCallback func FetchRowsCallback(p unsafe.Pointer, res *C.TAOS_RES, numOfRows C.int) { - caller := cgo.Handle(p).Value().(Caller) + caller := (*(*cgo.Handle)(p)).Value().(Caller) caller.FetchCall(unsafe.Pointer(res), int(numOfRows)) } //export FetchRawBlockCallback func FetchRawBlockCallback(p unsafe.Pointer, res *C.TAOS_RES, numOfRows C.int) { - caller := cgo.Handle(p).Value().(Caller) + caller := (*(*cgo.Handle)(p)).Value().(Caller) caller.FetchCall(unsafe.Pointer(res), int(numOfRows)) } diff --git a/wrapper/block.go b/wrapper/block.go index 210bb91..30ef5a6 100644 --- a/wrapper/block.go +++ b/wrapper/block.go @@ -8,82 +8,9 @@ package wrapper */ import "C" import ( - "database/sql/driver" - "math" "unsafe" - - "github.com/taosdata/driver-go/v3/common" -) - -const ( - Int8Size = unsafe.Sizeof(int8(0)) - Int16Size = unsafe.Sizeof(int16(0)) - Int32Size = unsafe.Sizeof(int32(0)) - Int64Size = unsafe.Sizeof(int64(0)) - UInt8Size = unsafe.Sizeof(uint8(0)) - UInt16Size = unsafe.Sizeof(uint16(0)) - UInt32Size = unsafe.Sizeof(uint32(0)) - UInt64Size = unsafe.Sizeof(uint64(0)) - Float32Size = unsafe.Sizeof(float32(0)) - Float64Size = unsafe.Sizeof(float64(0)) ) -const ( - ColInfoSize = Int8Size + Int32Size - RawBlockVersionOffset = 0 - RawBlockLengthOffset = RawBlockVersionOffset + Int32Size - NumOfRowsOffset = RawBlockLengthOffset + Int32Size - NumOfColsOffset = NumOfRowsOffset + Int32Size - HasColumnSegmentOffset = NumOfColsOffset + Int32Size - GroupIDOffset = HasColumnSegmentOffset + Int32Size - ColInfoOffset = GroupIDOffset + UInt64Size -) - -func RawBlockGetVersion(rawBlock unsafe.Pointer) int32 { - return *((*int32)(unsafe.Pointer(uintptr(rawBlock) + RawBlockVersionOffset))) -} - -func RawBlockGetLength(rawBlock unsafe.Pointer) int32 { - return *((*int32)(unsafe.Pointer(uintptr(rawBlock) + RawBlockLengthOffset))) -} - -func RawBlockGetNumOfRows(rawBlock unsafe.Pointer) int32 { - return *((*int32)(unsafe.Pointer(uintptr(rawBlock) + NumOfRowsOffset))) -} - -func RawBlockGetNumOfCols(rawBlock unsafe.Pointer) int32 { - return *((*int32)(unsafe.Pointer(uintptr(rawBlock) + NumOfColsOffset))) -} - -func RawBlockGetHasColumnSegment(rawBlock unsafe.Pointer) int32 { - return *((*int32)(unsafe.Pointer(uintptr(rawBlock) + HasColumnSegmentOffset))) -} - -func RawBlockGetGroupID(rawBlock unsafe.Pointer) uint64 { - return *((*uint64)(unsafe.Pointer(uintptr(rawBlock) + GroupIDOffset))) -} - -type RawBlockColInfo struct { - ColType int8 - Bytes int32 -} - -func RawBlockGetColInfo(rawBlock unsafe.Pointer, infos []RawBlockColInfo) { - for i := 0; i < len(infos); i++ { - offset := uintptr(rawBlock) + ColInfoOffset + ColInfoSize*uintptr(i) - infos[i].ColType = *((*int8)(unsafe.Pointer(offset))) - infos[i].Bytes = *((*int32)(unsafe.Pointer(offset + Int8Size))) - } -} - -func RawBlockGetColumnLengthOffset(colCount int) uintptr { - return ColInfoOffset + uintptr(colCount)*ColInfoSize -} - -func RawBlockGetColDataOffset(colCount int) uintptr { - return ColInfoOffset + uintptr(colCount)*ColInfoSize + uintptr(colCount)*Int32Size -} - // TaosFetchRawBlock int taos_fetch_raw_block(TAOS_RES *res, int* numOfRows, void** pData); func TaosFetchRawBlock(result unsafe.Pointer) (int, int, unsafe.Pointer) { var cSize int @@ -100,311 +27,9 @@ func TaosWriteRawBlock(conn unsafe.Pointer, numOfRows int, pData unsafe.Pointer, return int(C.taos_write_raw_block(conn, (C.int)(numOfRows), (*C.char)(pData), cStr)) } -func IsVarDataType(colType uint8) bool { - return colType == common.TSDB_DATA_TYPE_BINARY || colType == common.TSDB_DATA_TYPE_NCHAR || colType == common.TSDB_DATA_TYPE_JSON -} - -func BitmapLen(n int) int { - return ((n) + ((1 << 3) - 1)) >> 3 -} - -func BitPos(n int) int { - return n & ((1 << 3) - 1) -} - -func CharOffset(n int) int { - return n >> 3 -} - -func BMIsNull(c byte, n int) bool { - return c&(1<<(7-BitPos(n))) == (1 << (7 - BitPos(n))) -} - -type rawConvertFunc func(pStart uintptr, row int, arg ...interface{}) driver.Value - -type rawConvertVarDataFunc func(pHeader, pStart uintptr, row int) driver.Value - -var rawConvertFuncMap = map[uint8]rawConvertFunc{ - uint8(common.TSDB_DATA_TYPE_BOOL): rawConvertBool, - uint8(common.TSDB_DATA_TYPE_TINYINT): rawConvertTinyint, - uint8(common.TSDB_DATA_TYPE_SMALLINT): rawConvertSmallint, - uint8(common.TSDB_DATA_TYPE_INT): rawConvertInt, - uint8(common.TSDB_DATA_TYPE_BIGINT): rawConvertBigint, - uint8(common.TSDB_DATA_TYPE_UTINYINT): rawConvertUTinyint, - uint8(common.TSDB_DATA_TYPE_USMALLINT): rawConvertUSmallint, - uint8(common.TSDB_DATA_TYPE_UINT): rawConvertUInt, - uint8(common.TSDB_DATA_TYPE_UBIGINT): rawConvertUBigint, - uint8(common.TSDB_DATA_TYPE_FLOAT): rawConvertFloat, - uint8(common.TSDB_DATA_TYPE_DOUBLE): rawConvertDouble, - uint8(common.TSDB_DATA_TYPE_TIMESTAMP): rawConvertTime, -} - -var rawConvertVarDataMap = map[uint8]rawConvertVarDataFunc{ - uint8(common.TSDB_DATA_TYPE_BINARY): rawConvertBinary, - uint8(common.TSDB_DATA_TYPE_NCHAR): rawConvertNchar, - uint8(common.TSDB_DATA_TYPE_JSON): rawConvertJson, -} - -func ItemIsNull(pHeader uintptr, row int) bool { - offset := CharOffset(row) - c := *((*byte)(unsafe.Pointer(pHeader + uintptr(offset)))) - if BMIsNull(c, row) { - return true - } - return false -} - -func rawConvertBool(pStart uintptr, row int, _ ...interface{}) driver.Value { - if (*((*byte)(unsafe.Pointer(pStart + uintptr(row)*1)))) != 0 { - return true - } else { - return false - } -} - -func rawConvertTinyint(pStart uintptr, row int, _ ...interface{}) driver.Value { - return *((*int8)(unsafe.Pointer(pStart + uintptr(row)*Int8Size))) -} - -func rawConvertSmallint(pStart uintptr, row int, _ ...interface{}) driver.Value { - return *((*int16)(unsafe.Pointer(pStart + uintptr(row)*Int16Size))) -} - -func rawConvertInt(pStart uintptr, row int, _ ...interface{}) driver.Value { - return *((*int32)(unsafe.Pointer(pStart + uintptr(row)*Int32Size))) -} - -func rawConvertBigint(pStart uintptr, row int, _ ...interface{}) driver.Value { - return *((*int64)(unsafe.Pointer(pStart + uintptr(row)*Int64Size))) -} - -func rawConvertUTinyint(pStart uintptr, row int, _ ...interface{}) driver.Value { - return *((*uint8)(unsafe.Pointer(pStart + uintptr(row)*UInt8Size))) -} - -func rawConvertUSmallint(pStart uintptr, row int, _ ...interface{}) driver.Value { - return *((*uint16)(unsafe.Pointer(pStart + uintptr(row)*UInt16Size))) -} - -func rawConvertUInt(pStart uintptr, row int, _ ...interface{}) driver.Value { - return *((*uint32)(unsafe.Pointer(pStart + uintptr(row)*UInt32Size))) -} - -func rawConvertUBigint(pStart uintptr, row int, _ ...interface{}) driver.Value { - return *((*uint64)(unsafe.Pointer(pStart + uintptr(row)*UInt64Size))) -} - -func rawConvertFloat(pStart uintptr, row int, _ ...interface{}) driver.Value { - return math.Float32frombits(*((*uint32)(unsafe.Pointer(pStart + uintptr(row)*Float32Size)))) -} - -func rawConvertDouble(pStart uintptr, row int, _ ...interface{}) driver.Value { - return math.Float64frombits(*((*uint64)(unsafe.Pointer(pStart + uintptr(row)*Float64Size)))) -} - -func rawConvertTime(pStart uintptr, row int, arg ...interface{}) driver.Value { - if len(arg) == 1 { - return common.TimestampConvertToTime(*((*int64)(unsafe.Pointer(pStart + uintptr(row)*Int64Size))), arg[0].(int)) - } else if len(arg) == 2 { - return arg[1].(FormatTimeFunc)(*((*int64)(unsafe.Pointer(pStart + uintptr(row)*Int64Size))), arg[0].(int)) - } else { - panic("convertTime error") - } -} - -func rawConvertBinary(pHeader, pStart uintptr, row int) driver.Value { - offset := *((*int32)(unsafe.Pointer(pHeader + uintptr(row*4)))) - if offset == -1 { - return nil - } - currentRow := unsafe.Pointer(pStart + uintptr(offset)) - clen := *((*int16)(currentRow)) - currentRow = unsafe.Pointer(uintptr(currentRow) + 2) - - binaryVal := make([]byte, clen) - - for index := int16(0); index < clen; index++ { - binaryVal[index] = *((*byte)(unsafe.Pointer(uintptr(currentRow) + uintptr(index)))) - } - return string(binaryVal[:]) -} - -func rawConvertNchar(pHeader, pStart uintptr, row int) driver.Value { - offset := *((*int32)(unsafe.Pointer(pHeader + uintptr(row*4)))) - if offset == -1 { - return nil - } - currentRow := unsafe.Pointer(pStart + uintptr(offset)) - clen := *((*int16)(currentRow)) / 4 - currentRow = unsafe.Pointer(uintptr(currentRow) + 2) - - binaryVal := make([]rune, clen) - - for index := int16(0); index < clen; index++ { - binaryVal[index] = *((*rune)(unsafe.Pointer(uintptr(currentRow) + uintptr(index*4)))) - } - return string(binaryVal) -} - -// just like nchar -func rawConvertJson(pHeader, pStart uintptr, row int) driver.Value { - offset := *((*int32)(unsafe.Pointer(pHeader + uintptr(row*4)))) - if offset == -1 { - return nil - } - currentRow := unsafe.Pointer(pStart + uintptr(offset)) - clen := *((*int16)(currentRow)) - currentRow = unsafe.Pointer(uintptr(currentRow) + 2) - - binaryVal := make([]byte, clen) - - for index := int16(0); index < clen; index++ { - binaryVal[index] = *((*byte)(unsafe.Pointer(uintptr(currentRow) + uintptr(index)))) - } - return binaryVal[:] -} - -// ReadBlock in-place -func ReadBlock(block unsafe.Pointer, blockSize int, colTypes []uint8, precision int) [][]driver.Value { - r := make([][]driver.Value, blockSize) - colCount := len(colTypes) - nullBitMapOffset := uintptr(BitmapLen(blockSize)) - lengthOffset := RawBlockGetColumnLengthOffset(colCount) - pHeader := uintptr(block) + RawBlockGetColDataOffset(colCount) - pStart := pHeader - for column := 0; column < colCount; column++ { - colLength := *((*int32)(unsafe.Pointer(uintptr(block) + lengthOffset + uintptr(column)*Int32Size))) - if IsVarDataType(colTypes[column]) { - convertF := rawConvertVarDataMap[colTypes[column]] - pStart = pHeader + Int32Size*uintptr(blockSize) - for row := 0; row < blockSize; row++ { - if column == 0 { - r[row] = make([]driver.Value, colCount) - } - r[row][column] = convertF(pHeader, pStart, row) - } - } else { - convertF := rawConvertFuncMap[colTypes[column]] - pStart = pHeader + nullBitMapOffset - for row := 0; row < blockSize; row++ { - if column == 0 { - r[row] = make([]driver.Value, colCount) - } - if ItemIsNull(pHeader, row) { - r[row][column] = nil - } else { - r[row][column] = convertF(pStart, row, precision) - } - } - } - pHeader = pStart + uintptr(colLength) - } - return r -} - -func ReadRow(dest []driver.Value, block unsafe.Pointer, blockSize int, row int, colTypes []uint8, precision int) { - colCount := len(colTypes) - nullBitMapOffset := uintptr(BitmapLen(blockSize)) - lengthOffset := RawBlockGetColumnLengthOffset(colCount) - pHeader := uintptr(block) + RawBlockGetColDataOffset(colCount) - pStart := pHeader - for column := 0; column < colCount; column++ { - colLength := *((*int32)(unsafe.Pointer(uintptr(block) + lengthOffset + uintptr(column)*Int32Size))) - if IsVarDataType(colTypes[column]) { - convertF := rawConvertVarDataMap[colTypes[column]] - pStart = pHeader + Int32Size*uintptr(blockSize) - dest[column] = convertF(pHeader, pStart, row) - } else { - convertF := rawConvertFuncMap[colTypes[column]] - pStart = pHeader + nullBitMapOffset - if ItemIsNull(pHeader, row) { - dest[column] = nil - } else { - dest[column] = convertF(pStart, row, precision) - } - } - pHeader = pStart + uintptr(colLength) - } -} - -func ReadBlockWithTimeFormat(block unsafe.Pointer, blockSize int, colTypes []uint8, precision int, formatFunc FormatTimeFunc) [][]driver.Value { - r := make([][]driver.Value, blockSize) - colCount := len(colTypes) - nullBitMapOffset := uintptr(BitmapLen(blockSize)) - lengthOffset := RawBlockGetColumnLengthOffset(colCount) - pHeader := uintptr(block) + RawBlockGetColDataOffset(colCount) - pStart := pHeader - for column := 0; column < colCount; column++ { - colLength := *((*int32)(unsafe.Pointer(uintptr(block) + lengthOffset + uintptr(column)*Int32Size))) - if IsVarDataType(colTypes[column]) { - convertF := rawConvertVarDataMap[colTypes[column]] - pStart = pHeader + uintptr(4*blockSize) - for row := 0; row < blockSize; row++ { - if column == 0 { - r[row] = make([]driver.Value, colCount) - } - r[row][column] = convertF(pHeader, pStart, row) - } - } else { - convertF := rawConvertFuncMap[colTypes[column]] - pStart = pHeader + nullBitMapOffset - for row := 0; row < blockSize; row++ { - if column == 0 { - r[row] = make([]driver.Value, colCount) - } - if ItemIsNull(pHeader, row) { - r[row][column] = nil - } else { - r[row][column] = convertF(pStart, row, precision, formatFunc) - } - } - } - pHeader = pStart + uintptr(colLength) - } - return r -} - -func ItemRawBlock(colType uint8, pHeader, pStart uintptr, row int, precision int, timeFormat FormatTimeFunc) driver.Value { - if IsVarDataType(colType) { - switch colType { - case uint8(common.TSDB_DATA_TYPE_BINARY): - return rawConvertBinary(pHeader, pStart, row) - case uint8(common.TSDB_DATA_TYPE_NCHAR): - return rawConvertNchar(pHeader, pStart, row) - case uint8(common.TSDB_DATA_TYPE_JSON): - return rawConvertJson(pHeader, pStart, row) - } - } else { - if ItemIsNull(pHeader, row) { - return nil - } else { - switch colType { - case uint8(common.TSDB_DATA_TYPE_BOOL): - return rawConvertBool(pStart, row) - case uint8(common.TSDB_DATA_TYPE_TINYINT): - return rawConvertTinyint(pStart, row) - case uint8(common.TSDB_DATA_TYPE_SMALLINT): - return rawConvertSmallint(pStart, row) - case uint8(common.TSDB_DATA_TYPE_INT): - return rawConvertInt(pStart, row) - case uint8(common.TSDB_DATA_TYPE_BIGINT): - return rawConvertBigint(pStart, row) - case uint8(common.TSDB_DATA_TYPE_UTINYINT): - return rawConvertUTinyint(pStart, row) - case uint8(common.TSDB_DATA_TYPE_USMALLINT): - return rawConvertUSmallint(pStart, row) - case uint8(common.TSDB_DATA_TYPE_UINT): - return rawConvertUInt(pStart, row) - case uint8(common.TSDB_DATA_TYPE_UBIGINT): - return rawConvertUBigint(pStart, row) - case uint8(common.TSDB_DATA_TYPE_FLOAT): - return rawConvertFloat(pStart, row) - case uint8(common.TSDB_DATA_TYPE_DOUBLE): - return rawConvertDouble(pStart, row) - case uint8(common.TSDB_DATA_TYPE_TIMESTAMP): - return rawConvertTime(pStart, row, precision, timeFormat) - } - } - } - return nil +// TaosWriteRawBlockWithFields DLL_EXPORT int taos_write_raw_block_with_fields(TAOS* taos, int rows, char* pData, const char* tbname, TAOS_FIELD *fields, int numFields); +func TaosWriteRawBlockWithFields(conn unsafe.Pointer, numOfRows int, pData unsafe.Pointer, tableName string, fields unsafe.Pointer, numFields int) int { + cStr := C.CString(tableName) + defer C.free(unsafe.Pointer(cStr)) + return int(C.taos_write_raw_block_with_fields(conn, (C.int)(numOfRows), (*C.char)(pData), cStr, (*C.struct_taosField)(fields), (C.int)(numFields))) } diff --git a/wrapper/block_test.go b/wrapper/block_test.go index 1ac8271..e3b7b63 100644 --- a/wrapper/block_test.go +++ b/wrapper/block_test.go @@ -6,16 +6,12 @@ import ( "math" "testing" "time" - "unsafe" "github.com/stretchr/testify/assert" - "github.com/taosdata/driver-go/v3/common" + "github.com/taosdata/driver-go/v3/common/parser" "github.com/taosdata/driver-go/v3/errors" ) -// @author: xftan -// @date: 2022/4/16 15:12 -// @description: test for read raw block func TestReadBlock(t *testing.T) { conn, err := TaosConnect("", "root", "taosdata", "", 0) if err != nil { @@ -137,7 +133,7 @@ func TestReadBlock(t *testing.T) { if blockSize == 0 { break } - d := ReadBlock(block, blockSize, rh.ColTypes, precision) + d := parser.ReadBlock(block, blockSize, rh.ColTypes, precision) data = append(data, d...) } TaosFreeResult(res) @@ -182,7 +178,7 @@ func TestReadBlock(t *testing.T) { assert.Equal(t, []byte(`{"a":1}`), row3[14].([]byte)) } -func TestReadBlock2(t *testing.T) { +func TestTaosWriteRawBlock(t *testing.T) { conn, err := TaosConnect("", "root", "taosdata", "", 0) if err != nil { t.Error(err) @@ -190,9 +186,18 @@ func TestReadBlock2(t *testing.T) { } defer TaosClose(conn) + res := TaosQuery(conn, "drop database if exists test_write_block_raw") + code := TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) defer func() { - res := TaosQuery(conn, "drop database if exists test_block_raw") - code := TaosError(res) + res = TaosQuery(conn, "drop database if exists test_write_block_raw") + code = TaosError(res) if code != 0 { errStr := TaosErrorStr(res) TaosFreeResult(res) @@ -201,16 +206,7 @@ func TestReadBlock2(t *testing.T) { } TaosFreeResult(res) }() - res := TaosQuery(conn, "create database if not exists test_block_raw") - code := TaosError(res) - if code != 0 { - errStr := TaosErrorStr(res) - TaosFreeResult(res) - t.Error(errors.NewError(code, errStr)) - return - } - TaosFreeResult(res) - res = TaosQuery(conn, "drop table if exists test_block_raw.all_type2") + res = TaosQuery(conn, "create database test_write_block_raw") code = TaosError(res) if code != 0 { errStr := TaosErrorStr(res) @@ -219,7 +215,8 @@ func TestReadBlock2(t *testing.T) { return } TaosFreeResult(res) - res = TaosQuery(conn, "create table if not exists test_block_raw.all_type2 (ts timestamp,"+ + + res = TaosQuery(conn, "create table if not exists test_write_block_raw.all_type (ts timestamp,"+ "c1 bool,"+ "c2 tinyint,"+ "c3 smallint,"+ @@ -233,7 +230,7 @@ func TestReadBlock2(t *testing.T) { "c11 double,"+ "c12 binary(20),"+ "c13 nchar(20)"+ - ")") + ") tags (info json)") code = TaosError(res) if code != 0 { errStr := TaosErrorStr(res) @@ -244,7 +241,7 @@ func TestReadBlock2(t *testing.T) { TaosFreeResult(res) now := time.Now() after1s := now.Add(time.Second) - sql := fmt.Sprintf("insert into test_block_raw.all_type2 values('%s',1,1,1,1,1,1,1,1,1,1,1,'test_binary','test_nchar')('%s',null,null,null,null,null,null,null,null,null,null,null,null,null)", now.Format(time.RFC3339Nano), after1s.Format(time.RFC3339Nano)) + sql := fmt.Sprintf("insert into test_write_block_raw.t0 using test_write_block_raw.all_type tags('{\"a\":1}') values('%s',1,1,1,1,1,1,1,1,1,1,1,'test_binary','test_nchar')('%s',null,null,null,null,null,null,null,null,null,null,null,null,null)", now.Format(time.RFC3339Nano), after1s.Format(time.RFC3339Nano)) res = TaosQuery(conn, sql) code = TaosError(res) if code != 0 { @@ -255,7 +252,7 @@ func TestReadBlock2(t *testing.T) { } TaosFreeResult(res) - sql = "select * from test_block_raw.all_type2" + sql = "create table test_write_block_raw.t1 using test_write_block_raw.all_type tags('{\"a\":2}')" res = TaosQuery(conn, sql) code = TaosError(res) if code != 0 { @@ -264,130 +261,10 @@ func TestReadBlock2(t *testing.T) { t.Error(errors.NewError(code, errStr)) return } - fileCount := TaosNumFields(res) - rh, err := ReadColumn(res, fileCount) - if err != nil { - t.Error(err) - return - } - precision := TaosResultPrecision(res) - pHeaderList := make([]uintptr, fileCount) - pStartList := make([]uintptr, fileCount) - var data [][]driver.Value - for { - blockSize, errCode, block := TaosFetchRawBlock(res) - if errCode != int(errors.SUCCESS) { - errStr := TaosErrorStr(res) - err := errors.NewError(code, errStr) - t.Error(err) - TaosFreeResult(res) - return - } - if blockSize == 0 { - break - } - nullBitMapOffset := uintptr(BitmapLen(blockSize)) - lengthOffset := RawBlockGetColumnLengthOffset(fileCount) - tmpPHeader := uintptr(block) + RawBlockGetColDataOffset(fileCount) - tmpPStart := tmpPHeader - for column := 0; column < fileCount; column++ { - colLength := *((*int32)(unsafe.Pointer(uintptr(block) + lengthOffset + uintptr(column)*Int32Size))) - if IsVarDataType(rh.ColTypes[column]) { - pHeaderList[column] = tmpPHeader - tmpPStart = tmpPHeader + Int32Size*uintptr(blockSize) - pStartList[column] = tmpPStart - } else { - pHeaderList[column] = tmpPHeader - tmpPStart = tmpPHeader + nullBitMapOffset - pStartList[column] = tmpPStart - } - tmpPHeader = tmpPStart + uintptr(colLength) - } - for row := 0; row < blockSize; row++ { - rowV := make([]driver.Value, fileCount) - for column := 0; column < fileCount; column++ { - v := ItemRawBlock(rh.ColTypes[column], pHeaderList[column], pStartList[column], row, precision, func(ts int64, precision int) driver.Value { - return common.TimestampConvertToTime(ts, precision) - }) - rowV[column] = v - } - data = append(data, rowV) - } - } - TaosFreeResult(res) - assert.Equal(t, 2, len(data)) - row1 := data[0] - assert.Equal(t, now.UnixNano()/1e6, row1[0].(time.Time).UnixNano()/1e6) - assert.Equal(t, true, row1[1].(bool)) - assert.Equal(t, int8(1), row1[2].(int8)) - assert.Equal(t, int16(1), row1[3].(int16)) - assert.Equal(t, int32(1), row1[4].(int32)) - assert.Equal(t, int64(1), row1[5].(int64)) - assert.Equal(t, uint8(1), row1[6].(uint8)) - assert.Equal(t, uint16(1), row1[7].(uint16)) - assert.Equal(t, uint32(1), row1[8].(uint32)) - assert.Equal(t, uint64(1), row1[9].(uint64)) - assert.Equal(t, float32(1), row1[10].(float32)) - assert.Equal(t, float64(1), row1[11].(float64)) - assert.Equal(t, "test_binary", row1[12].(string)) - assert.Equal(t, "test_nchar", row1[13].(string)) - row2 := data[1] - assert.Equal(t, after1s.UnixNano()/1e6, row2[0].(time.Time).UnixNano()/1e6) - for i := 1; i < 14; i++ { - assert.Nil(t, row2[i]) - } -} - -func TestBlockTag(t *testing.T) { - conn, err := TaosConnect("", "root", "taosdata", "", 0) - if err != nil { - t.Error(err) - return - } - - defer TaosClose(conn) - defer func() { - res := TaosQuery(conn, "drop database if exists test_block_abc1") - code := TaosError(res) - if code != 0 { - errStr := TaosErrorStr(res) - TaosFreeResult(res) - t.Error(errors.NewError(code, errStr)) - return - } - TaosFreeResult(res) - }() - res := TaosQuery(conn, "create database if not exists test_block_abc1") - code := TaosError(res) - if code != 0 { - errStr := TaosErrorStr(res) - TaosFreeResult(res) - t.Error(errors.NewError(code, errStr)) - return - } - TaosFreeResult(res) - - res = TaosQuery(conn, "use test_block_abc1") - code = TaosError(res) - if code != 0 { - errStr := TaosErrorStr(res) - TaosFreeResult(res) - t.Error(errors.NewError(code, errStr)) - return - } - TaosFreeResult(res) - - res = TaosQuery(conn, "create table if not exists meters(ts timestamp, v int) tags(location varchar(16))") - code = TaosError(res) - if code != 0 { - errStr := TaosErrorStr(res) - TaosFreeResult(res) - t.Error(errors.NewError(code, errStr)) - return - } TaosFreeResult(res) - res = TaosQuery(conn, "create table if not exists tb1 using meters tags('abcd')") + sql = "use test_write_block_raw" + res = TaosQuery(conn, sql) code = TaosError(res) if code != 0 { errStr := TaosErrorStr(res) @@ -397,7 +274,7 @@ func TestBlockTag(t *testing.T) { } TaosFreeResult(res) - sql := "select distinct tbname,location from meters;" + sql = "select * from test_write_block_raw.t0" res = TaosQuery(conn, sql) code = TaosError(res) if code != 0 { @@ -406,21 +283,11 @@ func TestBlockTag(t *testing.T) { t.Error(errors.NewError(code, errStr)) return } - fileCount := TaosNumFields(res) - rh, err := ReadColumn(res, fileCount) - if err != nil { - t.Error(err) - return - } - precision := TaosResultPrecision(res) - pHeaderList := make([]uintptr, fileCount) - pStartList := make([]uintptr, fileCount) - var data [][]driver.Value for { blockSize, errCode, block := TaosFetchRawBlock(res) if errCode != int(errors.SUCCESS) { errStr := TaosErrorStr(res) - err := errors.NewError(code, errStr) + err := errors.NewError(errCode, errStr) t.Error(err) TaosFreeResult(res) return @@ -428,249 +295,19 @@ func TestBlockTag(t *testing.T) { if blockSize == 0 { break } - nullBitMapOffset := uintptr(BitmapLen(blockSize)) - lengthOffset := RawBlockGetColumnLengthOffset(fileCount) - tmpPHeader := uintptr(block) + RawBlockGetColDataOffset(fileCount) // length i32, group u64 - tmpPStart := tmpPHeader - for column := 0; column < fileCount; column++ { - colLength := *((*int32)(unsafe.Pointer(uintptr(block) + lengthOffset + uintptr(column)*Int32Size))) - if IsVarDataType(rh.ColTypes[column]) { - pHeaderList[column] = tmpPHeader - tmpPStart = tmpPHeader + Int32Size*uintptr(blockSize) - pStartList[column] = tmpPStart - } else { - pHeaderList[column] = tmpPHeader - tmpPStart = tmpPHeader + nullBitMapOffset - pStartList[column] = tmpPStart - } - tmpPHeader = tmpPStart + uintptr(colLength) - } - for row := 0; row < blockSize; row++ { - rowV := make([]driver.Value, fileCount) - for column := 0; column < fileCount; column++ { - v := ItemRawBlock(rh.ColTypes[column], pHeaderList[column], pStartList[column], row, precision, func(ts int64, precision int) driver.Value { - return common.TimestampConvertToTime(ts, precision) - }) - rowV[column] = v - } - data = append(data, rowV) - } - } - TaosFreeResult(res) - t.Log(data) - t.Log(len(data[0][1].(string))) -} - -func TestReadRow(t *testing.T) { - conn, err := TaosConnect("", "root", "taosdata", "", 0) - if err != nil { - t.Error(err) - return - } - - defer TaosClose(conn) - res := TaosQuery(conn, "drop database if exists test_read_row") - code := TaosError(res) - if code != 0 { - errStr := TaosErrorStr(res) - TaosFreeResult(res) - t.Error(errors.NewError(code, errStr)) - return - } - TaosFreeResult(res) - defer func() { - res = TaosQuery(conn, "drop database if exists test_read_row") - code = TaosError(res) - if code != 0 { - errStr := TaosErrorStr(res) - TaosFreeResult(res) - t.Error(errors.NewError(code, errStr)) - return - } - TaosFreeResult(res) - }() - res = TaosQuery(conn, "create database test_read_row") - code = TaosError(res) - if code != 0 { - errStr := TaosErrorStr(res) - TaosFreeResult(res) - t.Error(errors.NewError(code, errStr)) - return - } - TaosFreeResult(res) - - res = TaosQuery(conn, "create table if not exists test_read_row.all_type (ts timestamp,"+ - "c1 bool,"+ - "c2 tinyint,"+ - "c3 smallint,"+ - "c4 int,"+ - "c5 bigint,"+ - "c6 tinyint unsigned,"+ - "c7 smallint unsigned,"+ - "c8 int unsigned,"+ - "c9 bigint unsigned,"+ - "c10 float,"+ - "c11 double,"+ - "c12 binary(20),"+ - "c13 nchar(20)"+ - ") tags (info json)") - code = TaosError(res) - if code != 0 { - errStr := TaosErrorStr(res) - TaosFreeResult(res) - t.Error(errors.NewError(code, errStr)) - return - } - TaosFreeResult(res) - now := time.Now() - after1s := now.Add(time.Second) - sql := fmt.Sprintf("insert into test_read_row.t0 using test_read_row.all_type tags('{\"a\":1}') values('%s',1,1,1,1,1,1,1,1,1,1,1,'test_binary','test_nchar')('%s',null,null,null,null,null,null,null,null,null,null,null,null,null)", now.Format(time.RFC3339Nano), after1s.Format(time.RFC3339Nano)) - res = TaosQuery(conn, sql) - code = TaosError(res) - if code != 0 { - errStr := TaosErrorStr(res) - TaosFreeResult(res) - t.Error(errors.NewError(code, errStr)) - return - } - TaosFreeResult(res) - sql = "select * from test_read_row.all_type" - res = TaosQuery(conn, sql) - code = TaosError(res) - if code != 0 { - errStr := TaosErrorStr(res) - TaosFreeResult(res) - t.Error(errors.NewError(code, errStr)) - return - } - fileCount := TaosNumFields(res) - rh, err := ReadColumn(res, fileCount) - if err != nil { - t.Error(err) - return - } - precision := TaosResultPrecision(res) - var data [][]driver.Value - for { - blockSize, errCode, block := TaosFetchRawBlock(res) + errCode = TaosWriteRawBlock(conn, blockSize, block, "t1") if errCode != int(errors.SUCCESS) { - errStr := TaosErrorStr(res) - err := errors.NewError(code, errStr) + errStr := TaosErrorStr(nil) + err := errors.NewError(errCode, errStr) t.Error(err) TaosFreeResult(res) return } - if blockSize == 0 { - break - } - for i := 0; i < blockSize; i++ { - tmp := make([]driver.Value, fileCount) - ReadRow(tmp, block, blockSize, i, rh.ColTypes, precision) - data = append(data, tmp) - } - } - TaosFreeResult(res) - assert.Equal(t, 2, len(data)) - row1 := data[0] - assert.Equal(t, now.UnixNano()/1e6, row1[0].(time.Time).UnixNano()/1e6) - assert.Equal(t, true, row1[1].(bool)) - assert.Equal(t, int8(1), row1[2].(int8)) - assert.Equal(t, int16(1), row1[3].(int16)) - assert.Equal(t, int32(1), row1[4].(int32)) - assert.Equal(t, int64(1), row1[5].(int64)) - assert.Equal(t, uint8(1), row1[6].(uint8)) - assert.Equal(t, uint16(1), row1[7].(uint16)) - assert.Equal(t, uint32(1), row1[8].(uint32)) - assert.Equal(t, uint64(1), row1[9].(uint64)) - assert.Equal(t, float32(1), row1[10].(float32)) - assert.Equal(t, float64(1), row1[11].(float64)) - assert.Equal(t, "test_binary", row1[12].(string)) - assert.Equal(t, "test_nchar", row1[13].(string)) - assert.Equal(t, []byte(`{"a":1}`), row1[14].([]byte)) - row2 := data[1] - assert.Equal(t, after1s.UnixNano()/1e6, row2[0].(time.Time).UnixNano()/1e6) - for i := 1; i < 14; i++ { - assert.Nil(t, row2[i]) - } - assert.Equal(t, []byte(`{"a":1}`), row2[14].([]byte)) -} - -func TestReadBlockWithTimeFormat(t *testing.T) { - conn, err := TaosConnect("", "root", "taosdata", "", 0) - if err != nil { - t.Error(err) - return - } - - defer TaosClose(conn) - res := TaosQuery(conn, "drop database if exists test_read_block_tf") - code := TaosError(res) - if code != 0 { - errStr := TaosErrorStr(res) - TaosFreeResult(res) - t.Error(errors.NewError(code, errStr)) - return - } - TaosFreeResult(res) - defer func() { - res = TaosQuery(conn, "drop database if exists test_read_block_tf") - code = TaosError(res) - if code != 0 { - errStr := TaosErrorStr(res) - TaosFreeResult(res) - t.Error(errors.NewError(code, errStr)) - return - } - TaosFreeResult(res) - }() - res = TaosQuery(conn, "create database test_read_block_tf") - code = TaosError(res) - if code != 0 { - errStr := TaosErrorStr(res) - TaosFreeResult(res) - t.Error(errors.NewError(code, errStr)) - return - } - TaosFreeResult(res) - - res = TaosQuery(conn, "create table if not exists test_read_block_tf.all_type (ts timestamp,"+ - "c1 bool,"+ - "c2 tinyint,"+ - "c3 smallint,"+ - "c4 int,"+ - "c5 bigint,"+ - "c6 tinyint unsigned,"+ - "c7 smallint unsigned,"+ - "c8 int unsigned,"+ - "c9 bigint unsigned,"+ - "c10 float,"+ - "c11 double,"+ - "c12 binary(20),"+ - "c13 nchar(20)"+ - ") tags (info json)") - code = TaosError(res) - if code != 0 { - errStr := TaosErrorStr(res) - TaosFreeResult(res) - t.Error(errors.NewError(code, errStr)) - return - } - TaosFreeResult(res) - now := time.Now() - after1s := now.Add(time.Second) - sql := fmt.Sprintf("insert into test_read_block_tf.t0 using test_read_block_tf.all_type tags('{\"a\":1}') values('%s',false,1,1,1,1,1,1,1,1,1,1,'test_binary','test_nchar')('%s',null,null,null,null,null,null,null,null,null,null,null,null,null)", now.Format(time.RFC3339Nano), after1s.Format(time.RFC3339Nano)) - res = TaosQuery(conn, sql) - code = TaosError(res) - if code != 0 { - errStr := TaosErrorStr(res) - TaosFreeResult(res) - t.Error(errors.NewError(code, errStr)) - return } TaosFreeResult(res) - sql = "select * from test_read_block_tf.all_type" + sql = "select * from test_write_block_raw.t1" res = TaosQuery(conn, sql) code = TaosError(res) if code != 0 { @@ -699,15 +336,15 @@ func TestReadBlockWithTimeFormat(t *testing.T) { if blockSize == 0 { break } - data = ReadBlockWithTimeFormat(block, blockSize, rh.ColTypes, precision, func(ts int64, precision int) driver.Value { - return common.TimestampConvertToTime(ts, precision) - }) + d := parser.ReadBlock(block, blockSize, rh.ColTypes, precision) + data = append(data, d...) } TaosFreeResult(res) + assert.Equal(t, 2, len(data)) row1 := data[0] assert.Equal(t, now.UnixNano()/1e6, row1[0].(time.Time).UnixNano()/1e6) - assert.Equal(t, false, row1[1].(bool)) + assert.Equal(t, true, row1[1].(bool)) assert.Equal(t, int8(1), row1[2].(int8)) assert.Equal(t, int16(1), row1[3].(int16)) assert.Equal(t, int32(1), row1[4].(int32)) @@ -720,16 +357,14 @@ func TestReadBlockWithTimeFormat(t *testing.T) { assert.Equal(t, float64(1), row1[11].(float64)) assert.Equal(t, "test_binary", row1[12].(string)) assert.Equal(t, "test_nchar", row1[13].(string)) - assert.Equal(t, []byte(`{"a":1}`), row1[14].([]byte)) row2 := data[1] assert.Equal(t, after1s.UnixNano()/1e6, row2[0].(time.Time).UnixNano()/1e6) for i := 1; i < 14; i++ { assert.Nil(t, row2[i]) } - assert.Equal(t, []byte(`{"a":1}`), row2[14].([]byte)) } -func TestTaosWriteRawBlock(t *testing.T) { +func TestTaosWriteRawBlockWithFields(t *testing.T) { conn, err := TaosConnect("", "root", "taosdata", "", 0) if err != nil { t.Error(err) @@ -737,7 +372,7 @@ func TestTaosWriteRawBlock(t *testing.T) { } defer TaosClose(conn) - res := TaosQuery(conn, "drop database if exists test_write_block_raw") + res := TaosQuery(conn, "drop database if exists test_write_block_raw_fields") code := TaosError(res) if code != 0 { errStr := TaosErrorStr(res) @@ -746,18 +381,18 @@ func TestTaosWriteRawBlock(t *testing.T) { return } TaosFreeResult(res) - defer func() { - res = TaosQuery(conn, "drop database if exists test_write_block_raw") - code = TaosError(res) - if code != 0 { - errStr := TaosErrorStr(res) - TaosFreeResult(res) - t.Error(errors.NewError(code, errStr)) - return - } - TaosFreeResult(res) - }() - res = TaosQuery(conn, "create database test_write_block_raw") + //defer func() { + // res = TaosQuery(conn, "drop database if exists test_write_block_raw_fields") + // code = TaosError(res) + // if code != 0 { + // errStr := TaosErrorStr(res) + // TaosFreeResult(res) + // t.Error(errors.NewError(code, errStr)) + // return + // } + // TaosFreeResult(res) + //}() + res = TaosQuery(conn, "create database test_write_block_raw_fields") code = TaosError(res) if code != 0 { errStr := TaosErrorStr(res) @@ -767,7 +402,7 @@ func TestTaosWriteRawBlock(t *testing.T) { } TaosFreeResult(res) - res = TaosQuery(conn, "create table if not exists test_write_block_raw.all_type (ts timestamp,"+ + res = TaosQuery(conn, "create table if not exists test_write_block_raw_fields.all_type (ts timestamp,"+ "c1 bool,"+ "c2 tinyint,"+ "c3 smallint,"+ @@ -792,7 +427,7 @@ func TestTaosWriteRawBlock(t *testing.T) { TaosFreeResult(res) now := time.Now() after1s := now.Add(time.Second) - sql := fmt.Sprintf("insert into test_write_block_raw.t0 using test_write_block_raw.all_type tags('{\"a\":1}') values('%s',1,1,1,1,1,1,1,1,1,1,1,'test_binary','test_nchar')('%s',null,null,null,null,null,null,null,null,null,null,null,null,null)", now.Format(time.RFC3339Nano), after1s.Format(time.RFC3339Nano)) + sql := fmt.Sprintf("insert into test_write_block_raw_fields.t0 using test_write_block_raw_fields.all_type tags('{\"a\":1}') values('%s',1,1,1,1,1,1,1,1,1,1,1,'test_binary','test_nchar')('%s',null,null,null,null,null,null,null,null,null,null,null,null,null)", now.Format(time.RFC3339Nano), after1s.Format(time.RFC3339Nano)) res = TaosQuery(conn, sql) code = TaosError(res) if code != 0 { @@ -803,7 +438,7 @@ func TestTaosWriteRawBlock(t *testing.T) { } TaosFreeResult(res) - sql = fmt.Sprintf("create table test_write_block_raw.t1 using test_write_block_raw.all_type tags('{\"a\":2}')") + sql = "create table test_write_block_raw_fields.t1 using test_write_block_raw_fields.all_type tags('{\"a\":2}')" res = TaosQuery(conn, sql) code = TaosError(res) if code != 0 { @@ -814,7 +449,7 @@ func TestTaosWriteRawBlock(t *testing.T) { } TaosFreeResult(res) - sql = "use test_write_block_raw" + sql = "use test_write_block_raw_fields" res = TaosQuery(conn, sql) code = TaosError(res) if code != 0 { @@ -825,7 +460,7 @@ func TestTaosWriteRawBlock(t *testing.T) { } TaosFreeResult(res) - sql = "select * from test_write_block_raw.t0" + sql = "select ts,c1 from test_write_block_raw_fields.t0" res = TaosQuery(conn, sql) code = TaosError(res) if code != 0 { @@ -846,8 +481,10 @@ func TestTaosWriteRawBlock(t *testing.T) { if blockSize == 0 { break } + fieldsCount := TaosNumFields(res) + fields := TaosFetchFields(res) - errCode = TaosWriteRawBlock(conn, blockSize, block, "t1") + errCode = TaosWriteRawBlockWithFields(conn, blockSize, block, "t1", fields, fieldsCount) if errCode != int(errors.SUCCESS) { errStr := TaosErrorStr(nil) err := errors.NewError(errCode, errStr) @@ -858,7 +495,7 @@ func TestTaosWriteRawBlock(t *testing.T) { } TaosFreeResult(res) - sql = "select * from test_write_block_raw.t1" + sql = "select * from test_write_block_raw_fields.t1" res = TaosQuery(conn, sql) code = TaosError(res) if code != 0 { @@ -887,7 +524,7 @@ func TestTaosWriteRawBlock(t *testing.T) { if blockSize == 0 { break } - d := ReadBlock(block, blockSize, rh.ColTypes, precision) + d := parser.ReadBlock(block, blockSize, rh.ColTypes, precision) data = append(data, d...) } TaosFreeResult(res) @@ -896,233 +533,12 @@ func TestTaosWriteRawBlock(t *testing.T) { row1 := data[0] assert.Equal(t, now.UnixNano()/1e6, row1[0].(time.Time).UnixNano()/1e6) assert.Equal(t, true, row1[1].(bool)) - assert.Equal(t, int8(1), row1[2].(int8)) - assert.Equal(t, int16(1), row1[3].(int16)) - assert.Equal(t, int32(1), row1[4].(int32)) - assert.Equal(t, int64(1), row1[5].(int64)) - assert.Equal(t, uint8(1), row1[6].(uint8)) - assert.Equal(t, uint16(1), row1[7].(uint16)) - assert.Equal(t, uint32(1), row1[8].(uint32)) - assert.Equal(t, uint64(1), row1[9].(uint64)) - assert.Equal(t, float32(1), row1[10].(float32)) - assert.Equal(t, float64(1), row1[11].(float64)) - assert.Equal(t, "test_binary", row1[12].(string)) - assert.Equal(t, "test_nchar", row1[13].(string)) - row2 := data[1] - assert.Equal(t, after1s.UnixNano()/1e6, row2[0].(time.Time).UnixNano()/1e6) - for i := 1; i < 14; i++ { - assert.Nil(t, row2[i]) + for i := 2; i < 14; i++ { + assert.Nil(t, row1[i]) } -} - -func TestParseBlock(t *testing.T) { - conn, err := TaosConnect("", "root", "taosdata", "", 0) - if err != nil { - t.Error(err) - return - } - - defer TaosClose(conn) - res := TaosQuery(conn, "drop database if exists parse_block") - code := TaosError(res) - if code != 0 { - errStr := TaosErrorStr(res) - TaosFreeResult(res) - t.Error(errors.NewError(code, errStr)) - return - } - TaosFreeResult(res) - defer func() { - res = TaosQuery(conn, "drop database if exists parse_block") - code = TaosError(res) - if code != 0 { - errStr := TaosErrorStr(res) - TaosFreeResult(res) - t.Error(errors.NewError(code, errStr)) - return - } - TaosFreeResult(res) - }() - res = TaosQuery(conn, "create database parse_block vgroups 1") - code = TaosError(res) - if code != 0 { - errStr := TaosErrorStr(res) - TaosFreeResult(res) - t.Error(errors.NewError(code, errStr)) - return - } - TaosFreeResult(res) - - res = TaosQuery(conn, "create table if not exists parse_block.all_type (ts timestamp,"+ - "c1 bool,"+ - "c2 tinyint,"+ - "c3 smallint,"+ - "c4 int,"+ - "c5 bigint,"+ - "c6 tinyint unsigned,"+ - "c7 smallint unsigned,"+ - "c8 int unsigned,"+ - "c9 bigint unsigned,"+ - "c10 float,"+ - "c11 double,"+ - "c12 binary(20),"+ - "c13 nchar(20)"+ - ") tags (info json)") - code = TaosError(res) - if code != 0 { - errStr := TaosErrorStr(res) - TaosFreeResult(res) - t.Error(errors.NewError(code, errStr)) - return - } - TaosFreeResult(res) - now := time.Now() - after1s := now.Add(time.Second) - sql := fmt.Sprintf("insert into parse_block.t0 using parse_block.all_type tags('{\"a\":1}') values('%s',1,1,1,1,1,1,1,1,1,1,1,'test_binary','test_nchar')('%s',null,null,null,null,null,null,null,null,null,null,null,null,null)", now.Format(time.RFC3339Nano), after1s.Format(time.RFC3339Nano)) - res = TaosQuery(conn, sql) - code = TaosError(res) - if code != 0 { - errStr := TaosErrorStr(res) - TaosFreeResult(res) - t.Error(errors.NewError(code, errStr)) - return - } - TaosFreeResult(res) - - sql = "select * from parse_block.all_type" - res = TaosQuery(conn, sql) - code = TaosError(res) - if code != 0 { - errStr := TaosErrorStr(res) - TaosFreeResult(res) - t.Error(errors.NewError(code, errStr)) - return - } - fileCount := TaosNumFields(res) - rh, err := ReadColumn(res, fileCount) - if err != nil { - t.Error(err) - return - } - precision := TaosResultPrecision(res) - var data [][]driver.Value - for { - blockSize, errCode, block := TaosFetchRawBlock(res) - if errCode != int(errors.SUCCESS) { - errStr := TaosErrorStr(res) - err := errors.NewError(code, errStr) - t.Error(err) - TaosFreeResult(res) - return - } - if blockSize == 0 { - break - } - version := RawBlockGetVersion(block) - assert.Equal(t, int32(1), version) - length := RawBlockGetLength(block) - assert.Equal(t, int32(374), length) - rows := RawBlockGetNumOfRows(block) - assert.Equal(t, int32(2), rows) - columns := RawBlockGetNumOfCols(block) - assert.Equal(t, int32(15), columns) - hasColumnSegment := RawBlockGetHasColumnSegment(block) - assert.Equal(t, int32(-2147483648), hasColumnSegment) - groupId := RawBlockGetGroupID(block) - assert.Equal(t, uint64(0), groupId) - infos := make([]RawBlockColInfo, columns) - RawBlockGetColInfo(block, infos) - assert.Equal( - t, - []RawBlockColInfo{ - { - ColType: 9, - Bytes: 8, - }, - { - ColType: 1, - Bytes: 1, - }, - { - ColType: 2, - Bytes: 1, - }, - { - ColType: 3, - Bytes: 2, - }, - { - ColType: 4, - Bytes: 4, - }, - { - ColType: 5, - Bytes: 8, - }, - { - ColType: 11, - Bytes: 1, - }, - { - ColType: 12, - Bytes: 2, - }, - { - ColType: 13, - Bytes: 4, - }, - { - ColType: 14, - Bytes: 8, - }, - { - ColType: 6, - Bytes: 4, - }, - { - ColType: 7, - Bytes: 8, - }, - { - ColType: 8, - Bytes: 22, - }, - { - ColType: 10, - Bytes: 82, - }, - { - ColType: 15, - Bytes: 16384, - }, - }, - infos, - ) - d := ReadBlock(block, blockSize, rh.ColTypes, precision) - data = append(data, d...) - } - TaosFreeResult(res) - assert.Equal(t, 2, len(data)) - row1 := data[0] - assert.Equal(t, now.UnixNano()/1e6, row1[0].(time.Time).UnixNano()/1e6) - assert.Equal(t, true, row1[1].(bool)) - assert.Equal(t, int8(1), row1[2].(int8)) - assert.Equal(t, int16(1), row1[3].(int16)) - assert.Equal(t, int32(1), row1[4].(int32)) - assert.Equal(t, int64(1), row1[5].(int64)) - assert.Equal(t, uint8(1), row1[6].(uint8)) - assert.Equal(t, uint16(1), row1[7].(uint16)) - assert.Equal(t, uint32(1), row1[8].(uint32)) - assert.Equal(t, uint64(1), row1[9].(uint64)) - assert.Equal(t, float32(1), row1[10].(float32)) - assert.Equal(t, float64(1), row1[11].(float64)) - assert.Equal(t, "test_binary", row1[12].(string)) - assert.Equal(t, "test_nchar", row1[13].(string)) - assert.Equal(t, []byte(`{"a":1}`), row1[14].([]byte)) row2 := data[1] assert.Equal(t, after1s.UnixNano()/1e6, row2[0].(time.Time).UnixNano()/1e6) for i := 1; i < 14; i++ { assert.Nil(t, row2[i]) } - assert.Equal(t, []byte(`{"a":1}`), row2[14].([]byte)) } diff --git a/wrapper/cgo/handle.go b/wrapper/cgo/handle.go index 17bfbb9..586adcf 100644 --- a/wrapper/cgo/handle.go +++ b/wrapper/cgo/handle.go @@ -7,6 +7,7 @@ package cgo import ( "sync" "sync/atomic" + "unsafe" ) // Handle provides a way to pass values that contain Go pointers @@ -20,45 +21,7 @@ import ( // that is large enough to hold the bit pattern of any pointer. The zero // value of a Handle is not valid, and thus is safe to use as a sentinel // in C APIs. -// -// For instance, on the Go side: -// -// package main -// -// /* -// #include // for uintptr_t -// -// extern void MyGoPrint(uintptr_t handle); -// void myprint(uintptr_t handle); -// */ -// import "C" -// import "runtime/cgo" -// -// //export MyGoPrint -// func MyGoPrint(handle C.uintptr_t) { -// h := cgo.Handle(handle) -// val := h.Value().(string) -// println(val) -// h.Delete() -// } -// -// func main() { -// val := "hello Go" -// C.myprint(C.uintptr_t(cgo.NewHandle(val))) -// // Output: hello Go -// } -// -// and on the C side: -// -// #include // for uintptr_t -// -// // A Go function -// extern void MyGoPrint(uintptr_t handle); -// -// // A C function -// void myprint(uintptr_t handle) { -// MyGoPrint(handle); -// } + type Handle uintptr // NewHandle returns a handle for a given value. @@ -77,7 +40,9 @@ func NewHandle(v interface{}) Handle { } handles.Store(h, v) - return Handle(h) + handle := Handle(h) + handlePointers.Store(h, &handle) + return handle } // Value returns the associated Go value for a valid handle. @@ -91,6 +56,14 @@ func (h Handle) Value() interface{} { return v } +func (h Handle) Pointer() unsafe.Pointer { + p, ok := handlePointers.Load(uintptr(h)) + if !ok { + panic("runtime/cgo: misuse of an invalid Handle") + } + return unsafe.Pointer(p.(*Handle)) +} + // Delete invalidates a handle. This method should only be called once // the program no longer needs to pass the handle to C and the C code // no longer has a copy of the handle value. @@ -98,9 +71,11 @@ func (h Handle) Value() interface{} { // The method panics if the handle is invalid. func (h Handle) Delete() { handles.Delete(uintptr(h)) + handlePointers.Delete(uintptr(h)) } var ( - handles = sync.Map{} // map[Handle]interface{} - handleIdx uintptr // atomic + handles = sync.Map{} // map[Handle]interface{} + handlePointers = sync.Map{} // map[Handle]*Handle + handleIdx uintptr // atomic ) diff --git a/wrapper/row_test.go b/wrapper/row_test.go index 6c7e44f..b748c6c 100644 --- a/wrapper/row_test.go +++ b/wrapper/row_test.go @@ -587,6 +587,7 @@ func TestFetchRowAllType(t *testing.T) { t.Error(err) return } + precision := TaosResultPrecision(res) count := 0 result := make([]driver.Value, numFields) for { @@ -596,7 +597,7 @@ func TestFetchRowAllType(t *testing.T) { } count += 1 lengths := FetchLengths(res, numFields) - precision := TaosResultPrecision(rr) + for i := range header.ColTypes { result[i] = FetchRow(rr, i, header.ColTypes[i], lengths[i], precision) } diff --git a/wrapper/schemaless.go b/wrapper/schemaless.go index eb69d9e..bcfb4be 100644 --- a/wrapper/schemaless.go +++ b/wrapper/schemaless.go @@ -25,25 +25,163 @@ const ( ) // TaosSchemalessInsert TAOS_RES *taos_schemaless_insert(TAOS* taos, char* lines[], int numLines, int protocol, int precision); +// Deprecated func TaosSchemalessInsert(taosConnect unsafe.Pointer, lines []string, protocol int, precision string) unsafe.Pointer { - numLines := len(lines) - var cLines = make([]*C.char, numLines) - needFreeList := make([]unsafe.Pointer, numLines) + numLines, cLines, needFree := taosSchemalessInsertParams(lines) defer func() { - for _, p := range needFreeList { + for _, p := range needFree { C.free(p) } }() + return unsafe.Pointer(C.taos_schemaless_insert( + taosConnect, + (**C.char)(&cLines[0]), + (C.int)(numLines), + (C.int)(protocol), + (C.int)(exchange(precision)), + )) +} + +// TaosSchemalessInsertTTL TAOS_RES *taos_schemaless_insert_ttl(TAOS *taos, char *lines[], int numLines, int protocol, int precision, int32_t ttl) +// Deprecated +func TaosSchemalessInsertTTL(taosConnect unsafe.Pointer, lines []string, protocol int, precision string, ttl int) unsafe.Pointer { + numLines, cLines, needFree := taosSchemalessInsertParams(lines) + defer func() { + for _, p := range needFree { + C.free(p) + } + }() + return unsafe.Pointer(C.taos_schemaless_insert_ttl( + taosConnect, + (**C.char)(&cLines[0]), + (C.int)(numLines), + (C.int)(protocol), + (C.int)(exchange(precision)), + (C.int32_t)(ttl), + )) +} + +// TaosSchemalessInsertWithReqID TAOS_RES *taos_schemaless_insert_with_reqid(TAOS *taos, char *lines[], int numLines, int protocol, int precision, int64_t reqid); +// Deprecated +func TaosSchemalessInsertWithReqID(taosConnect unsafe.Pointer, lines []string, protocol int, precision string, reqID int64) unsafe.Pointer { + numLines, cLines, needFree := taosSchemalessInsertParams(lines) + defer func() { + for _, p := range needFree { + C.free(p) + } + }() + return unsafe.Pointer(C.taos_schemaless_insert_with_reqid( + taosConnect, + (**C.char)(&cLines[0]), + (C.int)(numLines), + (C.int)(protocol), + (C.int)(exchange(precision)), + (C.int64_t)(reqID), + )) +} + +// TaosSchemalessInsertTTLWithReqID TAOS_RES *taos_schemaless_insert_ttl_with_reqid(TAOS *taos, char *lines[], int numLines, int protocol, int precision, int32_t ttl, int64_t reqid) +// Deprecated +func TaosSchemalessInsertTTLWithReqID(taosConnect unsafe.Pointer, lines []string, protocol int, precision string, ttl int, reqID int64) unsafe.Pointer { + numLines, cLines, needFree := taosSchemalessInsertParams(lines) + defer func() { + for _, p := range needFree { + C.free(p) + } + }() + return unsafe.Pointer(C.taos_schemaless_insert_ttl_with_reqid( + taosConnect, + (**C.char)(&cLines[0]), + (C.int)(numLines), + (C.int)(protocol), + (C.int)(exchange(precision)), + (C.int32_t)(ttl), + (C.int64_t)(reqID), + )) +} + +func taosSchemalessInsertParams(lines []string) (numLines int, cLines []*C.char, needFree []unsafe.Pointer) { + numLines = len(lines) + cLines = make([]*C.char, numLines) + needFree = make([]unsafe.Pointer, numLines) for i, line := range lines { cLine := C.CString(line) - needFreeList[i] = unsafe.Pointer(cLine) + needFree[i] = unsafe.Pointer(cLine) cLines[i] = cLine } - if len(precision) == 0 { - return C.taos_schemaless_insert(taosConnect, (**C.char)(&cLines[0]), (C.int)(numLines), (C.int)(protocol), (C.int)(TSDB_SML_TIMESTAMP_NOT_CONFIGURED)) - } else { - return C.taos_schemaless_insert(taosConnect, (**C.char)(&cLines[0]), (C.int)(numLines), (C.int)(protocol), (C.int)(exchange(precision))) - } + return +} + +// TaosSchemalessInsertRaw TAOS_RES *taos_schemaless_insert_raw(TAOS* taos, char* lines, int len, int32_t *totalRows, int protocol, int precision); +func TaosSchemalessInsertRaw(taosConnect unsafe.Pointer, lines string, protocol int, precision string) (int32, unsafe.Pointer) { + cLine := C.CString(lines) + defer C.free(unsafe.Pointer(cLine)) + var rows int32 + pTotalRows := unsafe.Pointer(&rows) + result := unsafe.Pointer(C.taos_schemaless_insert_raw( + taosConnect, + cLine, + (C.int)(len(lines)), + (*C.int32_t)(pTotalRows), + (C.int)(protocol), + (C.int)(exchange(precision)), + )) + return rows, result +} + +// TaosSchemalessInsertRawTTL TAOS_RES *taos_schemaless_insert_raw_ttl(TAOS *taos, char *lines, int len, int32_t *totalRows, int protocol, int precision, int32_t ttl); +func TaosSchemalessInsertRawTTL(taosConnect unsafe.Pointer, lines string, protocol int, precision string, ttl int) (int32, unsafe.Pointer) { + cLine := C.CString(lines) + defer C.free(unsafe.Pointer(cLine)) + var rows int32 + pTotalRows := unsafe.Pointer(&rows) + result := unsafe.Pointer(C.taos_schemaless_insert_raw_ttl( + taosConnect, + cLine, + (C.int)(len(lines)), + (*C.int32_t)(pTotalRows), + (C.int)(protocol), + (C.int)(exchange(precision)), + (C.int32_t)(ttl), + )) + return rows, result +} + +// TaosSchemalessInsertRawWithReqID TAOS_RES *taos_schemaless_insert_raw_with_reqid(TAOS *taos, char *lines, int len, int32_t *totalRows, int protocol, int precision, int64_t reqid); +func TaosSchemalessInsertRawWithReqID(taosConnect unsafe.Pointer, lines string, protocol int, precision string, reqID int64) (int32, unsafe.Pointer) { + cLine := C.CString(lines) + defer C.free(unsafe.Pointer(cLine)) + var rows int32 + pTotalRows := unsafe.Pointer(&rows) + result := unsafe.Pointer(C.taos_schemaless_insert_raw_with_reqid( + taosConnect, + cLine, + (C.int)(len(lines)), + (*C.int32_t)(pTotalRows), + (C.int)(protocol), + (C.int)(exchange(precision)), + (C.int64_t)(reqID), + )) + return rows, result +} + +// TaosSchemalessInsertRawTTLWithReqID TAOS_RES *taos_schemaless_insert_raw_ttl_with_reqid(TAOS *taos, char *lines, int len, int32_t *totalRows, int protocol, int precision, int32_t ttl, int64_t reqid) +func TaosSchemalessInsertRawTTLWithReqID(taosConnect unsafe.Pointer, lines string, protocol int, precision string, ttl int, reqID int64) (int32, unsafe.Pointer) { + cLine := C.CString(lines) + defer C.free(unsafe.Pointer(cLine)) + var rows int32 + pTotalRows := unsafe.Pointer(&rows) + result := C.taos_schemaless_insert_raw_ttl_with_reqid( + taosConnect, + cLine, + (C.int)(len(lines)), + (*C.int32_t)(pTotalRows), + (C.int)(protocol), + (C.int)(exchange(precision)), + (C.int32_t)(ttl), + (C.int64_t)(reqID), + ) + return rows, result } func exchange(ts string) int { diff --git a/wrapper/schemaless_test.go b/wrapper/schemaless_test.go index 7b6d5cb..6edbbdd 100644 --- a/wrapper/schemaless_test.go +++ b/wrapper/schemaless_test.go @@ -1,6 +1,7 @@ package wrapper_test import ( + "strings" "testing" "time" "unsafe" @@ -13,7 +14,6 @@ func prepareEnv() unsafe.Pointer { conn, err := wrapper.TaosConnect("", "root", "taosdata", "", 0) if err != nil { panic(err) - return nil } res := wrapper.TaosQuery(conn, "create database if not exists test_schemaless_common") if wrapper.TaosError(res) != 0 { @@ -88,7 +88,7 @@ func TestSchemalessTelnet(t *testing.T) { return } wrapper.TaosFreeResult(result) - t.Log("finish ", time.Now().Sub(s)) + t.Log("finish ", time.Since(s)) } // @author: xftan @@ -203,3 +203,450 @@ func TestSchemalessInfluxDB(t *testing.T) { wrapper.TaosFreeResult(result) } } + +func TestSchemalessRawTelnet(t *testing.T) { + conn := prepareEnv() + defer wrapper.TaosClose(conn) + defer cleanEnv(conn) + type in struct { + rows []string + } + data := []in{ + { + rows: []string{"sys_if_bytes_out 1636626444 1.3E3 host=web01 interface=eth0"}, + }, + { + rows: []string{"sys_if_bytes_out 1636626444 1.3E3 host=web01 interface=eth0"}, + }, + } + for _, d := range data { + row := strings.Join(d.rows, "\n") + totalRows, result := wrapper.TaosSchemalessInsertRaw(conn, row, wrapper.OpenTSDBTelnetLineProtocol, "") + code := wrapper.TaosError(result) + if code != 0 { + errStr := wrapper.TaosErrorStr(result) + wrapper.TaosFreeResult(result) + t.Log(row) + t.Error(errors.NewError(code, errStr)) + return + } + if int(totalRows) != len(d.rows) { + t.Log(row) + t.Errorf("expect rows %d got %d", len(d.rows), totalRows) + } + affected := wrapper.TaosAffectedRows(result) + if affected != len(d.rows) { + t.Log(row) + t.Errorf("expect affected %d got %d", len(d.rows), affected) + } + wrapper.TaosFreeResult(result) + } +} + +func TestSchemalessRawInfluxDB(t *testing.T) { + conn := prepareEnv() + defer wrapper.TaosClose(conn) + defer cleanEnv(conn) + type in struct { + rows []string + precision string + } + data := []in{ + { + rows: []string{"measurement,host=host1 field1=2i,field2=2.0 1577836800000000000"}, + precision: "", + }, + { + rows: []string{"measurement,host=host1 field1=2i,field2=2.0 1577836800000000000"}, + precision: "ns", + }, + { + rows: []string{"measurement,host=host1 field1=2i,field2=2.0 1577836800000000"}, + precision: "u", + }, + { + rows: []string{"measurement,host=host1 field1=2i,field2=2.0 1577836800000000"}, + precision: "μ", + }, + { + rows: []string{"measurement,host=host1 field1=2i,field2=2.0 1577836800000"}, + precision: "ms", + }, + { + rows: []string{"measurement,host=host1 field1=2i,field2=2.0 1577836800"}, + precision: "s", + }, + { + rows: []string{"measurement,host=host1 field1=2i,field2=2.0 26297280"}, + precision: "m", + }, + { + rows: []string{"measurement,host=host1 field1=2i,field2=2.0 438288"}, + precision: "h", + }, + { + rows: []string{"cpu_value,host=xyzzy,instance=0,type=cpu,type_instance=user value=63843347 1665212955372077566\n"}, + precision: "ns", + }, + } + for _, d := range data { + row := strings.Join(d.rows, "\n") + totalRows, result := wrapper.TaosSchemalessInsertRaw(conn, row, wrapper.InfluxDBLineProtocol, d.precision) + code := wrapper.TaosError(result) + if code != 0 { + errStr := wrapper.TaosErrorStr(result) + wrapper.TaosFreeResult(result) + t.Log(row) + t.Error(errors.NewError(code, errStr)) + return + } + if int(totalRows) != len(d.rows) { + t.Log(row) + t.Errorf("expect rows %d got %d", len(d.rows), totalRows) + } + affected := wrapper.TaosAffectedRows(result) + if affected != len(d.rows) { + t.Log(row) + t.Errorf("expect affected %d got %d", len(d.rows), affected) + } + wrapper.TaosFreeResult(result) + } +} + +func TestTaosSchemalessInsertRawWithReqID(t *testing.T) { + conn := prepareEnv() + defer wrapper.TaosClose(conn) + defer cleanEnv(conn) + cases := []struct { + name string + row string + rows int32 + precision string + reqID int64 + }{ + { + name: "1", + row: "measurement,host=host1 field1=2i,field2=2.0 1577836800000000000", + rows: 1, + precision: "", + reqID: 1, + }, + { + name: "2", + row: "measurement,host=host1 field1=2i,field2=2.0 1577836900000000000", + rows: 1, + precision: "ns", + reqID: 2, + }, + { + name: "3", + row: "measurement,host=host1 field1=2i,field2=2.0 1577837000000000", + rows: 1, + precision: "u", + reqID: 3, + }, + { + name: "4", + row: "measurement,host=host1 field1=2i,field2=2.0 1577837100000000", + rows: 1, + precision: "μ", + reqID: 4, + }, + { + name: "5", + row: "measurement,host=host1 field1=2i,field2=2.0 1577837200000\n" + + "measurement,host=host1 field1=2i,field2=2.0 1577837300000", + rows: 2, + precision: "ms", + reqID: 5, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + rows, result := wrapper.TaosSchemalessInsertRawWithReqID(conn, c.row, wrapper.InfluxDBLineProtocol, c.precision, c.reqID) + if rows != c.rows { + t.Fatal("rows miss") + } + code := wrapper.TaosError(result) + if code != 0 { + errStr := wrapper.TaosErrorStr(result) + t.Fatal(errors.NewError(code, errStr)) + } + wrapper.TaosFreeResult(result) + }) + } +} + +func TestTaosSchemalessInsertWithReqID(t *testing.T) { + conn := prepareEnv() + defer wrapper.TaosClose(conn) + defer cleanEnv(conn) + cases := []struct { + name string + rows []string + precision string + reqID int64 + }{ + { + name: "1", + rows: []string{"measurement,host=host1 field1=2i,field2=2.0 1577836800000000000"}, + precision: "", + reqID: 1, + }, + { + name: "2", + rows: []string{"measurement,host=host1 field1=2i,field2=2.0 1577836900000000000"}, + precision: "ns", + reqID: 2, + }, + { + name: "3", + rows: []string{"measurement,host=host1 field1=2i,field2=2.0 1577837000000000"}, + precision: "u", + reqID: 3, + }, + { + name: "4", + rows: []string{"measurement,host=host1 field1=2i,field2=2.0 1577837100000000"}, + precision: "μ", + reqID: 4, + }, + { + name: "5", + rows: []string{ + "measurement,host=host1 field1=2i,field2=2.0 1577837200000", + "measurement,host=host1 field1=2i,field2=2.0 1577837300000", + }, + precision: "ms", + reqID: 5, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + result := wrapper.TaosSchemalessInsertWithReqID(conn, c.rows, wrapper.InfluxDBLineProtocol, c.precision, c.reqID) + code := wrapper.TaosError(result) + if code != 0 { + errStr := wrapper.TaosErrorStr(result) + t.Fatal(errors.NewError(code, errStr)) + } + wrapper.TaosFreeResult(result) + }) + } +} + +func TestTaosSchemalessInsertTTL(t *testing.T) { + conn := prepareEnv() + defer wrapper.TaosClose(conn) + defer cleanEnv(conn) + cases := []struct { + name string + rows []string + precision string + ttl int + }{ + { + name: "1", + rows: []string{"measurement,host=host1 field1=2i,field2=2.0 1577836800000000000"}, + precision: "", + ttl: 1000, + }, + { + name: "2", + rows: []string{"measurement,host=host1 field1=2i,field2=2.0 1577836900000000000"}, + precision: "ns", + ttl: 1200, + }, + { + name: "3", + rows: []string{"measurement,host=host1 field1=2i,field2=2.0 1577837100000000"}, + precision: "μ", + ttl: 1400, + }, + { + name: "4", + rows: []string{ + "measurement,host=host1 field1=2i,field2=2.0 1577837200000", + "measurement,host=host1 field1=2i,field2=2.0 1577837300000", + }, + precision: "ms", + ttl: 1600, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + result := wrapper.TaosSchemalessInsertTTL(conn, c.rows, wrapper.InfluxDBLineProtocol, c.precision, c.ttl) + code := wrapper.TaosError(result) + if code != 0 { + errStr := wrapper.TaosErrorStr(result) + t.Fatal(errors.NewError(code, errStr)) + } + wrapper.TaosFreeResult(result) + }) + } +} + +func TestTaosSchemalessInsertTTLWithReqID(t *testing.T) { + conn := prepareEnv() + defer wrapper.TaosClose(conn) + defer cleanEnv(conn) + cases := []struct { + name string + rows []string + precision string + ttl int + reqId int64 + }{ + { + name: "1", + rows: []string{"measurement,host=host1 field1=2i,field2=2.0 1577836800000000000"}, + precision: "", + ttl: 1000, + reqId: 1, + }, + { + name: "2", + rows: []string{"measurement,host=host1 field1=2i,field2=2.0 1577836900000000000"}, + precision: "ns", + ttl: 1200, + reqId: 2, + }, + { + name: "3", + rows: []string{"measurement,host=host1 field1=2i,field2=2.0 1577837100000000"}, + precision: "μ", + ttl: 1400, + reqId: 3, + }, + { + name: "4", + rows: []string{ + "measurement,host=host1 field1=2i,field2=2.0 1577837200000", + "measurement,host=host1 field1=2i,field2=2.0 1577837300000", + }, + precision: "ms", + ttl: 1600, + reqId: 4, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + result := wrapper.TaosSchemalessInsertTTLWithReqID(conn, c.rows, wrapper.InfluxDBLineProtocol, c.precision, c.ttl, c.reqId) + code := wrapper.TaosError(result) + if code != 0 { + errStr := wrapper.TaosErrorStr(result) + t.Fatal(errors.NewError(code, errStr)) + } + wrapper.TaosFreeResult(result) + }) + } +} + +func TestTaosSchemalessInsertRawTTL(t *testing.T) { + conn := prepareEnv() + defer wrapper.TaosClose(conn) + defer cleanEnv(conn) + cases := []struct { + name string + row string + rows int32 + precision string + ttl int + }{ + { + name: "1", + row: "measurement,host=host1 field1=2i,field2=2.0 1577836800000000000", + rows: 1, + precision: "", + ttl: 1000, + }, + { + name: "2", + row: "measurement,host=host1 field1=2i,field2=2.0 1577836900000000000", + rows: 1, + precision: "ns", + ttl: 1200, + }, + { + name: "3", + row: "measurement,host=host1 field1=2i,field2=2.0 1577837200000\n" + + "measurement,host=host1 field1=2i,field2=2.0 1577837300000", + rows: 2, + precision: "ms", + ttl: 1400, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + rows, result := wrapper.TaosSchemalessInsertRawTTL(conn, c.row, wrapper.InfluxDBLineProtocol, c.precision, c.ttl) + if rows != c.rows { + t.Fatal("rows miss") + } + code := wrapper.TaosError(result) + if code != 0 { + errStr := wrapper.TaosErrorStr(result) + t.Fatal(errors.NewError(code, errStr)) + } + wrapper.TaosFreeResult(result) + }) + } +} + +func TestTaosSchemalessInsertRawTTLWithReqID(t *testing.T) { + conn := prepareEnv() + defer wrapper.TaosClose(conn) + defer cleanEnv(conn) + cases := []struct { + name string + row string + rows int32 + precision string + ttl int + reqID int64 + }{ + { + name: "1", + row: "measurement,host=host1 field1=2i,field2=2.0 1577836800000000000", + rows: 1, + precision: "", + ttl: 1000, + reqID: 1, + }, + { + name: "2", + row: "measurement,host=host1 field1=2i,field2=2.0 1577836900000000000", + rows: 1, + precision: "ns", + ttl: 1200, + reqID: 2, + }, + { + name: "3", + row: "measurement,host=host1 field1=2i,field2=2.0 1577837200000\n" + + "measurement,host=host1 field1=2i,field2=2.0 1577837300000", + rows: 2, + precision: "ms", + ttl: 1400, + reqID: 3, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + rows, result := wrapper.TaosSchemalessInsertRawTTLWithReqID(conn, c.row, wrapper.InfluxDBLineProtocol, c.precision, c.ttl, c.reqID) + if rows != c.rows { + t.Fatal("rows miss") + } + code := wrapper.TaosError(result) + if code != 0 { + errStr := wrapper.TaosErrorStr(result) + t.Fatal(errors.NewError(code, errStr)) + } + wrapper.TaosFreeResult(result) + }) + } +} diff --git a/wrapper/stmt.go b/wrapper/stmt.go index 4a6aa86..cbf5e30 100644 --- a/wrapper/stmt.go +++ b/wrapper/stmt.go @@ -23,6 +23,11 @@ func TaosStmtInit(taosConnect unsafe.Pointer) unsafe.Pointer { return C.taos_stmt_init(taosConnect) } +// TaosStmtInitWithReqID TAOS_STMT *taos_stmt_init_with_reqid(TAOS *taos, int64_t reqid); +func TaosStmtInitWithReqID(taosConn unsafe.Pointer, reqID int64) unsafe.Pointer { + return C.taos_stmt_init_with_reqid(taosConn, (C.int64_t)(reqID)) +} + // TaosStmtPrepare int taos_stmt_prepare(TAOS_STMT *stmt, const char *sql, unsigned long length); func TaosStmtPrepare(stmt unsafe.Pointer, sql string) int { cSql := C.CString(sql) @@ -136,10 +141,9 @@ func generateTaosBindList(params []driver.Value) ([]C.TAOS_MULTI_BIND, []unsafe. needFreePointer = append(needFreePointer, p) bind.is_null = (*C.char)(p) } else { - switch param.(type) { + switch value := param.(type) { case taosTypes.TaosBool: bind.buffer_type = C.TSDB_DATA_TYPE_BOOL - value := param.(taosTypes.TaosBool) p := C.malloc(1) if value { *(*C.int8_t)(p) = C.int8_t(1) @@ -151,7 +155,6 @@ func generateTaosBindList(params []driver.Value) ([]C.TAOS_MULTI_BIND, []unsafe. bind.buffer_length = C.uintptr_t(1) case taosTypes.TaosTinyint: bind.buffer_type = C.TSDB_DATA_TYPE_TINYINT - value := param.(taosTypes.TaosTinyint) p := C.malloc(1) *(*C.int8_t)(p) = C.int8_t(value) needFreePointer = append(needFreePointer, p) @@ -159,14 +162,12 @@ func generateTaosBindList(params []driver.Value) ([]C.TAOS_MULTI_BIND, []unsafe. bind.buffer_length = C.uintptr_t(1) case taosTypes.TaosSmallint: bind.buffer_type = C.TSDB_DATA_TYPE_SMALLINT - value := param.(taosTypes.TaosSmallint) p := C.malloc(2) *(*C.int16_t)(p) = C.int16_t(value) needFreePointer = append(needFreePointer, p) bind.buffer = p bind.buffer_length = C.uintptr_t(2) case taosTypes.TaosInt: - value := param.(taosTypes.TaosInt) bind.buffer_type = C.TSDB_DATA_TYPE_INT p := C.malloc(4) *(*C.int32_t)(p) = C.int32_t(value) @@ -175,7 +176,6 @@ func generateTaosBindList(params []driver.Value) ([]C.TAOS_MULTI_BIND, []unsafe. bind.buffer_length = C.uintptr_t(4) case taosTypes.TaosBigint: bind.buffer_type = C.TSDB_DATA_TYPE_BIGINT - value := param.(taosTypes.TaosBigint) p := C.malloc(8) *(*C.int64_t)(p) = C.int64_t(value) needFreePointer = append(needFreePointer, p) @@ -183,15 +183,13 @@ func generateTaosBindList(params []driver.Value) ([]C.TAOS_MULTI_BIND, []unsafe. bind.buffer_length = C.uintptr_t(8) case taosTypes.TaosUTinyint: bind.buffer_type = C.TSDB_DATA_TYPE_UTINYINT - buf := param.(taosTypes.TaosUTinyint) cbuf := C.malloc(1) - *(*C.uint8_t)(cbuf) = C.uint8_t(buf) + *(*C.uint8_t)(cbuf) = C.uint8_t(value) needFreePointer = append(needFreePointer, cbuf) bind.buffer = cbuf bind.buffer_length = C.uintptr_t(1) case taosTypes.TaosUSmallint: bind.buffer_type = C.TSDB_DATA_TYPE_USMALLINT - value := param.(taosTypes.TaosUSmallint) p := C.malloc(2) *(*C.uint16_t)(p) = C.uint16_t(value) needFreePointer = append(needFreePointer, p) @@ -199,7 +197,6 @@ func generateTaosBindList(params []driver.Value) ([]C.TAOS_MULTI_BIND, []unsafe. bind.buffer_length = C.uintptr_t(2) case taosTypes.TaosUInt: bind.buffer_type = C.TSDB_DATA_TYPE_UINT - value := param.(taosTypes.TaosUInt) p := C.malloc(4) *(*C.uint32_t)(p) = C.uint32_t(value) needFreePointer = append(needFreePointer, p) @@ -207,7 +204,6 @@ func generateTaosBindList(params []driver.Value) ([]C.TAOS_MULTI_BIND, []unsafe. bind.buffer_length = C.uintptr_t(4) case taosTypes.TaosUBigint: bind.buffer_type = C.TSDB_DATA_TYPE_UBIGINT - value := param.(taosTypes.TaosUBigint) p := C.malloc(8) *(*C.uint64_t)(p) = C.uint64_t(value) needFreePointer = append(needFreePointer, p) @@ -215,7 +211,6 @@ func generateTaosBindList(params []driver.Value) ([]C.TAOS_MULTI_BIND, []unsafe. bind.buffer_length = C.uintptr_t(8) case taosTypes.TaosFloat: bind.buffer_type = C.TSDB_DATA_TYPE_FLOAT - value := param.(taosTypes.TaosFloat) p := C.malloc(4) *(*C.float)(p) = C.float(value) needFreePointer = append(needFreePointer, p) @@ -223,7 +218,6 @@ func generateTaosBindList(params []driver.Value) ([]C.TAOS_MULTI_BIND, []unsafe. bind.buffer_length = C.uintptr_t(4) case taosTypes.TaosDouble: bind.buffer_type = C.TSDB_DATA_TYPE_DOUBLE - value := param.(taosTypes.TaosDouble) p := C.malloc(8) *(*C.double)(p) = C.double(value) needFreePointer = append(needFreePointer, p) @@ -231,11 +225,10 @@ func generateTaosBindList(params []driver.Value) ([]C.TAOS_MULTI_BIND, []unsafe. bind.buffer_length = C.uintptr_t(8) case taosTypes.TaosBinary: bind.buffer_type = C.TSDB_DATA_TYPE_BINARY - buf := param.(taosTypes.TaosBinary) - cbuf := C.CString(string(buf)) + cbuf := C.CString(string(value)) needFreePointer = append(needFreePointer, unsafe.Pointer(cbuf)) bind.buffer = unsafe.Pointer(cbuf) - clen := int32(len(buf)) + clen := int32(len(value)) p := C.malloc(C.size_t(unsafe.Sizeof(clen))) bind.length = (*C.int32_t)(p) *(bind.length) = C.int32_t(clen) @@ -243,7 +236,6 @@ func generateTaosBindList(params []driver.Value) ([]C.TAOS_MULTI_BIND, []unsafe. bind.buffer_length = C.uintptr_t(clen) case taosTypes.TaosNchar: bind.buffer_type = C.TSDB_DATA_TYPE_NCHAR - value := param.(taosTypes.TaosNchar) p := unsafe.Pointer(C.CString(string(value))) needFreePointer = append(needFreePointer, p) bind.buffer = unsafe.Pointer(p) @@ -254,8 +246,7 @@ func generateTaosBindList(params []driver.Value) ([]C.TAOS_MULTI_BIND, []unsafe. bind.buffer_length = C.uintptr_t(clen) case taosTypes.TaosTimestamp: bind.buffer_type = C.TSDB_DATA_TYPE_TIMESTAMP - v := param.(taosTypes.TaosTimestamp) - ts := common.TimeToTimestamp(v.T, v.Precision) + ts := common.TimeToTimestamp(value.T, value.Precision) p := C.malloc(8) needFreePointer = append(needFreePointer, p) *(*C.int64_t)(p) = C.int64_t(ts) @@ -263,11 +254,10 @@ func generateTaosBindList(params []driver.Value) ([]C.TAOS_MULTI_BIND, []unsafe. bind.buffer_length = C.uintptr_t(8) case taosTypes.TaosJson: bind.buffer_type = C.TSDB_DATA_TYPE_JSON - buf := param.(taosTypes.TaosJson) - cbuf := C.CString(string(buf)) + cbuf := C.CString(string(value)) needFreePointer = append(needFreePointer, unsafe.Pointer(cbuf)) bind.buffer = unsafe.Pointer(cbuf) - clen := int32(len(buf)) + clen := int32(len(value)) p := C.malloc(C.size_t(unsafe.Sizeof(clen))) bind.length = (*C.int32_t)(p) *(bind.length) = C.int32_t(clen) @@ -302,7 +292,7 @@ func TaosStmtClose(stmt unsafe.Pointer) int { return int(C.taos_stmt_close(stmt)) } -//TaosStmtSetSubTBName int taos_stmt_set_sub_tbname(TAOS_STMT* stmt, const char* name); +// TaosStmtSetSubTBName int taos_stmt_set_sub_tbname(TAOS_STMT* stmt, const char* name); func TaosStmtSetSubTBName(stmt unsafe.Pointer, name string) int { cStr := C.CString(name) defer C.free(unsafe.Pointer(cStr)) @@ -595,6 +585,7 @@ func TaosStmtAffectedRowsOnce(stmt unsafe.Pointer) int { //uint8_t scale; //int32_t bytes; //} TAOS_FIELD_E; + type StmtField struct { Name string FieldType int8 @@ -692,3 +683,8 @@ func StmtParseFields(num int, fields unsafe.Pointer) []*StmtField { } return result } + +// TaosStmtReclaimFields DLL_EXPORT void taos_stmt_reclaim_fields(TAOS_STMT *stmt, TAOS_FIELD_E *fields); +func TaosStmtReclaimFields(stmt unsafe.Pointer, fields unsafe.Pointer) { + C.taos_stmt_reclaim_fields(stmt, (*C.TAOS_FIELD_E)(fields)) +} diff --git a/wrapper/stmt_test.go b/wrapper/stmt_test.go index 1ffce5a..a7db700 100644 --- a/wrapper/stmt_test.go +++ b/wrapper/stmt_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/taosdata/driver-go/v3/common" "github.com/taosdata/driver-go/v3/common/param" + "github.com/taosdata/driver-go/v3/common/parser" taosError "github.com/taosdata/driver-go/v3/errors" taosTypes "github.com/taosdata/driver-go/v3/types" ) @@ -628,7 +629,7 @@ func query(conn unsafe.Pointer, sql string) ([][]driver.Value, error) { if columns == 0 { break } - r := ReadBlock(block, columns, rh.ColTypes, precision) + r := parser.ReadBlock(block, columns, rh.ColTypes, precision) result = append(result, r...) } return result, nil @@ -677,7 +678,7 @@ func StmtQuery(t *testing.T, conn unsafe.Pointer, sql string, params *param.Para if blockSize == 0 { break } - d := ReadBlock(block, blockSize, rowsHeader.ColTypes, precision) + d := parser.ReadBlock(block, blockSize, rowsHeader.ColTypes, precision) data = append(data, d...) } TaosFreeResult(res) @@ -759,6 +760,7 @@ func TestGetFields(t *testing.T) { t.Error(err) return } + defer TaosStmtReclaimFields(stmt, tagsP) code, columnCount, columnsP := TaosStmtGetColFields(stmt) if code != 0 { errStr := TaosStmtErrStr(stmt) @@ -766,6 +768,7 @@ func TestGetFields(t *testing.T) { t.Error(err) return } + defer TaosStmtReclaimFields(stmt, columnsP) columns := StmtParseFields(columnCount, columnsP) tags := StmtParseFields(tagCount, tagsP) assert.Equal(t, []*StmtField{ @@ -858,6 +861,7 @@ func TestGetFieldsCommonTable(t *testing.T) { t.Error(err) return } + defer TaosStmtReclaimFields(stmt, columnsP) columns := StmtParseFields(columnCount, columnsP) assert.Equal(t, []*StmtField{ {"ts", 9, 0, 0, 8}, @@ -932,6 +936,7 @@ func TestTaosStmtSetTags(t *testing.T) { stmt := TaosStmtInit(conn) if stmt == nil { err = taosError.NewError(0xffff, "failed to init stmt") + t.Error(err) return } //defer TaosStmtClose(stmt) diff --git a/wrapper/taosc.go b/wrapper/taosc.go index 811c9c2..803ed16 100644 --- a/wrapper/taosc.go +++ b/wrapper/taosc.go @@ -4,7 +4,7 @@ package wrapper #cgo CFLAGS: -IC:/TDengine/include -I/usr/include #cgo linux LDFLAGS: -L/usr/lib -ltaos #cgo windows LDFLAGS: -LC:/TDengine/driver -ltaos -#cgo darwin LDFLAGS: -L/usr/local/taos/driver -ltaos +#cgo darwin LDFLAGS: -L/usr/local/lib -ltaos #include #include #include @@ -21,6 +21,9 @@ void taos_fetch_rows_a_wrapper(TAOS_RES *res, void *param){ void taos_query_a_wrapper(TAOS *taos,const char *sql, void *param){ return taos_query_a(taos,sql,QueryCallback,param); }; +void taos_query_a_with_req_id_wrapper(TAOS *taos,const char *sql, void *param, int64_t reqID){ + return taos_query_a_with_reqid(taos, sql, QueryCallback, param, reqID); +}; void taos_fetch_raw_block_a_wrapper(TAOS_RES *res, void *param){ return taos_fetch_raw_block_a(res,FetchRawBlockCallback,param); }; @@ -78,6 +81,13 @@ func TaosQuery(taosConnect unsafe.Pointer, sql string) unsafe.Pointer { return unsafe.Pointer(C.taos_query(taosConnect, cSql)) } +// TasoQueryWithReqID TAOS_RES *taos_query_with_reqid(TAOS *taos, const char *sql, int64_t reqID); +func TasoQueryWithReqID(taosConn unsafe.Pointer, sql string, reqID int64) unsafe.Pointer { + cSql := C.CString(sql) + defer C.free(unsafe.Pointer(cSql)) + return unsafe.Pointer(C.taos_query_with_reqid(taosConn, cSql, (C.int64_t)(reqID))) +} + // TaosError int taos_errno(TAOS_RES *tres); func TaosError(result unsafe.Pointer) int { return int(C.taos_errno(result)) @@ -144,12 +154,19 @@ func TaosOptions(option int, value string) int { func TaosQueryA(taosConnect unsafe.Pointer, sql string, caller cgo.Handle) { cSql := C.CString(sql) defer C.free(unsafe.Pointer(cSql)) - C.taos_query_a_wrapper(taosConnect, cSql, unsafe.Pointer(caller)) + C.taos_query_a_wrapper(taosConnect, cSql, caller.Pointer()) +} + +// TaosQueryAWithReqID void taos_query_a_with_reqid(TAOS *taos, const char *sql, __taos_async_fn_t fp, void *param, int64_t reqid); +func TaosQueryAWithReqID(taosConn unsafe.Pointer, sql string, caller cgo.Handle, reqID int64) { + cSql := C.CString(sql) + defer C.free(unsafe.Pointer(cSql)) + C.taos_query_a_with_req_id_wrapper(taosConn, cSql, caller.Pointer(), (C.int64_t)(reqID)) } // TaosFetchRowsA void taos_fetch_rows_a(TAOS_RES *res, void (*fp)(void *param, TAOS_RES *, int numOfRows), void *param); func TaosFetchRowsA(res unsafe.Pointer, caller cgo.Handle) { - C.taos_fetch_rows_a_wrapper(res, unsafe.Pointer(caller)) + C.taos_fetch_rows_a_wrapper(res, caller.Pointer()) } // TaosResetCurrentDB void taos_reset_current_db(TAOS *taos); @@ -176,7 +193,7 @@ func TaosFetchLengths(res unsafe.Pointer) unsafe.Pointer { // TaosFetchRawBlockA void taos_fetch_raw_block_a(TAOS_RES* res, __taos_async_fn_t fp, void* param); func TaosFetchRawBlockA(res unsafe.Pointer, caller cgo.Handle) { - C.taos_fetch_raw_block_a_wrapper(res, unsafe.Pointer(caller)) + C.taos_fetch_raw_block_a_wrapper(res, caller.Pointer()) } // TaosGetRawBlock const void *taos_get_raw_block(TAOS_RES* res); @@ -196,3 +213,45 @@ func TaosLoadTableInfo(taosConnect unsafe.Pointer, tableNameList []string) int { defer C.free(unsafe.Pointer(buf)) return int(C.taos_load_table_info(taosConnect, buf)) } + +// TaosGetTableVgID +// DLL_EXPORT int taos_get_table_vgId(TAOS *taos, const char *db, const char *table, int *vgId) +func TaosGetTableVgID(conn unsafe.Pointer, db, table string) (vgID int, code int) { + cDB := C.CString(db) + defer C.free(unsafe.Pointer(cDB)) + cTable := C.CString(table) + defer C.free(unsafe.Pointer(cTable)) + + code = int(C.taos_get_table_vgId(conn, cDB, cTable, (*C.int)(unsafe.Pointer(&vgID)))) + return +} + +// TaosGetTablesVgID DLL_EXPORT int taos_get_tables_vgId(TAOS *taos, const char *db, const char *table[], int tableNum, int *vgId) +func TaosGetTablesVgID(conn unsafe.Pointer, db string, tables []string) (vgIDs []int, code int) { + cDB := C.CString(db) + defer C.free(unsafe.Pointer(cDB)) + numTables := len(tables) + cTables := make([]*C.char, numTables) + needFree := make([]unsafe.Pointer, numTables) + defer func() { + for _, p := range needFree { + C.free(p) + } + }() + for i, table := range tables { + cTable := C.CString(table) + needFree[i] = unsafe.Pointer(cTable) + cTables[i] = cTable + } + p := C.malloc(C.sizeof_int * C.size_t(numTables)) + defer C.free(p) + code = int(C.taos_get_tables_vgId(conn, cDB, (**C.char)(&cTables[0]), (C.int)(numTables), (*C.int)(p))) + if code != 0 { + return nil, code + } + vgIDs = make([]int, numTables) + for i := 0; i < numTables; i++ { + vgIDs[i] = int(*(*C.int)(unsafe.Pointer(uintptr(p) + uintptr(C.sizeof_int*C.int(i))))) + } + return +} diff --git a/wrapper/taosc_test.go b/wrapper/taosc_test.go index be464f1..5e41dfa 100644 --- a/wrapper/taosc_test.go +++ b/wrapper/taosc_test.go @@ -2,12 +2,15 @@ package wrapper import ( "database/sql/driver" + "fmt" "io" "testing" + "time" "unsafe" "github.com/stretchr/testify/assert" "github.com/taosdata/driver-go/v3/common" + "github.com/taosdata/driver-go/v3/common/parser" "github.com/taosdata/driver-go/v3/errors" "github.com/taosdata/driver-go/v3/wrapper/cgo" ) @@ -408,7 +411,7 @@ func TestTaosResultBlock(t *testing.T) { res = r.res block := TaosGetRawBlock(res) assert.NotNil(t, block) - values := ReadBlock(block, r.n, rowsHeader.ColTypes, precision) + values := parser.ReadBlock(block, r.n, rowsHeader.ColTypes, precision) _ = values t.Log(values) } @@ -459,3 +462,93 @@ func TestTaosLoadTableInfo(t *testing.T) { } } + +func TestTaosGetTableVgID(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Fatal(err) + } + defer TaosClose(conn) + dbName := "table_vg_id_test" + + _ = exec(conn, fmt.Sprintf("drop database if exists %s", dbName)) + defer func() { + _ = exec(conn, fmt.Sprintf("drop database if exists %s", dbName)) + }() + if err = exec(conn, fmt.Sprintf("create database %s", dbName)); err != nil { + t.Fatal(err) + } + if err = exec(conn, fmt.Sprintf("create stable %s.meters (ts timestamp, current float, voltage int, phase float) "+ + "tags (location binary(64), groupId int)", dbName)); err != nil { + t.Fatal(err) + } + if err = exec(conn, fmt.Sprintf("create table %s.d0 using %s.meters tags ('California.SanFrancisco', 1)", dbName, dbName)); err != nil { + t.Fatal(err) + } + if err = exec(conn, fmt.Sprintf("create table %s.d1 using %s.meters tags ('California.LosAngles', 2)", dbName, dbName)); err != nil { + t.Fatal(err) + } + + vg1, code := TaosGetTableVgID(conn, dbName, "d0") + if code != 0 { + t.Fatal("fail") + } + vg2, code := TaosGetTableVgID(conn, dbName, "d0") + if code != 0 { + t.Fatal("fail") + } + if vg1 != vg2 { + t.Fatal("fail") + } + _, code = TaosGetTableVgID(conn, dbName, "d1") + if code != 0 { + t.Fatal("fail") + } + _, code = TaosGetTableVgID(conn, dbName, "d2") + if code != 0 { + t.Fatal("fail") + } +} + +func TestTaosGetTablesVgID(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Fatal(err) + } + defer TaosClose(conn) + dbName := "tables_vg_id_test" + + _ = exec(conn, fmt.Sprintf("drop database if exists %s", dbName)) + defer func() { + _ = exec(conn, fmt.Sprintf("drop database if exists %s", dbName)) + }() + if err = exec(conn, fmt.Sprintf("create database %s", dbName)); err != nil { + t.Fatal(err) + } + if err = exec(conn, fmt.Sprintf("create stable %s.meters (ts timestamp, current float, voltage int, phase float) "+ + "tags (location binary(64), groupId int)", dbName)); err != nil { + t.Fatal(err) + } + if err = exec(conn, fmt.Sprintf("create table %s.d0 using %s.meters tags ('California.SanFrancisco', 1)", dbName, dbName)); err != nil { + t.Fatal(err) + } + if err = exec(conn, fmt.Sprintf("create table %s.d1 using %s.meters tags ('California.LosAngles', 2)", dbName, dbName)); err != nil { + t.Fatal(err) + } + var vgs1 []int + var vgs2 []int + var code int + now := time.Now() + vgs1, code = TaosGetTablesVgID(conn, dbName, []string{"d0", "d1"}) + fmt.Println(time.Since(now)) + if code != 0 { + t.Fatal("fail") + } + assert.Equal(t, 2, len(vgs1)) + vgs2, code = TaosGetTablesVgID(conn, dbName, []string{"d0", "d1"}) + if code != 0 { + t.Fatal("fail") + } + assert.Equal(t, 2, len(vgs2)) + assert.Equal(t, vgs2, vgs1) +} diff --git a/wrapper/tmq.go b/wrapper/tmq.go index 17d15dc..93960ec 100644 --- a/wrapper/tmq.go +++ b/wrapper/tmq.go @@ -67,12 +67,12 @@ func TMQConfDestroy(conf unsafe.Pointer) { // TMQConfSetAutoCommitCB DLL_EXPORT void tmq_conf_set_auto_commit_cb(tmq_conf_t *conf, tmq_commit_cb *cb, void *param); func TMQConfSetAutoCommitCB(conf unsafe.Pointer, h cgo.Handle) { - C.tmq_conf_set_auto_commit_cb((*C.struct_tmq_conf_t)(conf), (*C.tmq_commit_cb)(C.TMQCommitCB), unsafe.Pointer(h)) + C.tmq_conf_set_auto_commit_cb((*C.struct_tmq_conf_t)(conf), (*C.tmq_commit_cb)(C.TMQCommitCB), h.Pointer()) } // TMQCommitAsync DLL_EXPORT void tmq_commit_async(tmq_t *tmq, const TAOS_RES *msg, tmq_commit_cb *cb, void *param); func TMQCommitAsync(consumer unsafe.Pointer, message unsafe.Pointer, h cgo.Handle) { - C.tmq_commit_async((*C.tmq_t)(consumer), message, (*C.tmq_commit_cb)(C.TMQCommitCB), unsafe.Pointer(h)) + C.tmq_commit_async((*C.tmq_t)(consumer), message, (*C.tmq_commit_cb)(C.TMQCommitCB), h.Pointer()) } // TMQCommitSync DLL_EXPORT int32_t tmq_commit_sync(tmq_t *tmq, const TAOS_RES *msg); @@ -119,10 +119,10 @@ func TMQConsumerNew(conf unsafe.Pointer) (unsafe.Pointer, error) { tmq := unsafe.Pointer(C.tmq_consumer_new((*C.struct_tmq_conf_t)(conf), p, C.int32_t(1024))) errStr := C.GoString(p) if len(errStr) > 0 { - return tmq, errors.NewError(-1, errStr) + return nil, errors.NewError(-1, errStr) } if tmq == nil { - panic("new consumer return nil") + return nil, errors.NewError(-1, "new consumer return nil") } return tmq, nil } diff --git a/wrapper/tmq_test.go b/wrapper/tmq_test.go index c4942f5..3d27e64 100644 --- a/wrapper/tmq_test.go +++ b/wrapper/tmq_test.go @@ -9,6 +9,8 @@ import ( jsoniter "github.com/json-iterator/go" "github.com/stretchr/testify/assert" "github.com/taosdata/driver-go/v3/common" + "github.com/taosdata/driver-go/v3/common/parser" + tmqcommon "github.com/taosdata/driver-go/v3/common/tmq" "github.com/taosdata/driver-go/v3/errors" "github.com/taosdata/driver-go/v3/wrapper/cgo" ) @@ -31,7 +33,7 @@ func TestTMQ(t *testing.T) { } TaosFreeResult(result) }() - result := TaosQuery(conn, "create database if not exists abc1 vgroups 2") + result := TaosQuery(conn, "create database if not exists abc1 vgroups 2 WAL_RETENTION_PERIOD 86400") code := TaosError(result) if code != 0 { errStr := TaosErrorStr(result) @@ -147,12 +149,9 @@ func TestTMQ(t *testing.T) { h := cgo.NewHandle(c) TMQConfSetAutoCommitCB(conf, h) go func() { - for { - select { - case r := <-c: - t.Log("auto commit", r) - PutTMQCommitCallbackResult(r) - } + for r := range c { + t.Log("auto commit", r) + PutTMQCommitCallbackResult(r) } }() tmq, err := TMQConsumerNew(conf) @@ -172,7 +171,7 @@ func TestTMQ(t *testing.T) { t.Error(errors.NewError(int(errCode), errStr)) return } - t.Log("sub", time.Now().Sub(s)) + t.Log("sub", time.Since(s)) errCode, list := TMQSubscription(tmq) if errCode != 0 { errStr := TMQErr2Str(errCode) @@ -213,11 +212,11 @@ func TestTMQ(t *testing.T) { return } precision := TaosResultPrecision(message) - tableName := TMQGetTableName(message) - assert.Equal(t, "ct1", tableName) + //tableName := TMQGetTableName(message) + //assert.Equal(t, "ct1", tableName) dbName := TMQGetDBName(message) assert.Equal(t, "abc1", dbName) - data := ReadBlock(block, blockSize, rh.ColTypes, precision) + data := parser.ReadBlock(block, blockSize, rh.ColTypes, precision) t.Log(data) } TaosFreeResult(message) @@ -274,7 +273,7 @@ func TestTMQDB(t *testing.T) { } TaosFreeResult(result) }() - result := TaosQuery(conn, "create database if not exists tmq_test_db vgroups 2") + result := TaosQuery(conn, "create database if not exists tmq_test_db vgroups 2 WAL_RETENTION_PERIOD 86400") code := TaosError(result) if code != 0 { errStr := TaosErrorStr(result) @@ -379,12 +378,9 @@ func TestTMQDB(t *testing.T) { h := cgo.NewHandle(c) TMQConfSetAutoCommitCB(conf, h) go func() { - for { - select { - case r := <-c: - t.Log("auto commit", r) - PutTMQCommitCallbackResult(r) - } + for r := range c { + t.Log("auto commit", r) + PutTMQCommitCallbackResult(r) } }() tmq, err := TMQConsumerNew(conf) @@ -443,7 +439,7 @@ func TestTMQDB(t *testing.T) { } precision := TaosResultPrecision(message) totalCount += blockSize - data := ReadBlock(block, blockSize, rh.ColTypes, precision) + data := parser.ReadBlock(block, blockSize, rh.ColTypes, precision) t.Log(data) } TaosFreeResult(message) @@ -492,7 +488,7 @@ func TestTMQDBMultiTable(t *testing.T) { } TaosFreeResult(result) }() - result := TaosQuery(conn, "create database if not exists tmq_test_db_multi vgroups 2") + result := TaosQuery(conn, "create database if not exists tmq_test_db_multi vgroups 2 WAL_RETENTION_PERIOD 86400") code := TaosError(result) if code != 0 { errStr := TaosErrorStr(result) @@ -606,12 +602,9 @@ func TestTMQDBMultiTable(t *testing.T) { h := cgo.NewHandle(c) TMQConfSetAutoCommitCB(conf, h) go func() { - for { - select { - case r := <-c: - t.Log("auto commit", r) - PutTMQCommitCallbackResult(r) - } + for r := range c { + t.Log("auto commit", r) + PutTMQCommitCallbackResult(r) } }() tmq, err := TMQConsumerNew(conf) @@ -675,7 +668,7 @@ func TestTMQDBMultiTable(t *testing.T) { } precision := TaosResultPrecision(message) totalCount += blockSize - data := ReadBlock(block, blockSize, rh.ColTypes, precision) + data := parser.ReadBlock(block, blockSize, rh.ColTypes, precision) t.Log(data) } TaosFreeResult(message) @@ -729,7 +722,7 @@ func TestTMQDBMultiInsert(t *testing.T) { } TaosFreeResult(result) }() - result := TaosQuery(conn, "create database if not exists tmq_test_db_multi_insert vgroups 2") + result := TaosQuery(conn, "create database if not exists tmq_test_db_multi_insert vgroups 2 WAL_RETENTION_PERIOD 86400") code := TaosError(result) if code != 0 { errStr := TaosErrorStr(result) @@ -821,12 +814,9 @@ func TestTMQDBMultiInsert(t *testing.T) { h := cgo.NewHandle(c) TMQConfSetAutoCommitCB(conf, h) go func() { - for { - select { - case r := <-c: - t.Log("auto commit", r) - PutTMQCommitCallbackResult(r) - } + for r := range c { + t.Log("auto commit", r) + PutTMQCommitCallbackResult(r) } }() tmq, err := TMQConsumerNew(conf) @@ -887,7 +877,7 @@ func TestTMQDBMultiInsert(t *testing.T) { } precision := TaosResultPrecision(message) totalCount += blockSize - data := ReadBlock(block, blockSize, rh.ColTypes, precision) + data := parser.ReadBlock(block, blockSize, rh.ColTypes, precision) t.Log(data) } TaosFreeResult(message) @@ -967,7 +957,7 @@ func TestTMQModify(t *testing.T) { } TaosFreeResult(result) - result = TaosQuery(conn, "create database if not exists tmq_test_db_modify_target vgroups 2") + result = TaosQuery(conn, "create database if not exists tmq_test_db_modify_target vgroups 2 WAL_RETENTION_PERIOD 86400") code = TaosError(result) if code != 0 { errStr := TaosErrorStr(result) @@ -977,7 +967,7 @@ func TestTMQModify(t *testing.T) { } TaosFreeResult(result) - result = TaosQuery(conn, "create database if not exists tmq_test_db_modify vgroups 5") + result = TaosQuery(conn, "create database if not exists tmq_test_db_modify vgroups 5 WAL_RETENTION_PERIOD 86400") code = TaosError(result) if code != 0 { errStr := TaosErrorStr(result) @@ -1028,12 +1018,9 @@ func TestTMQModify(t *testing.T) { h := cgo.NewHandle(c) TMQConfSetAutoCommitCB(conf, h) go func() { - for { - select { - case r := <-c: - t.Log("auto commit", r) - PutTMQCommitCallbackResult(r) - } + for r := range c { + t.Log("auto commit", r) + PutTMQCommitCallbackResult(r) } }() tmq, err := TMQConsumerNew(conf) @@ -1105,7 +1092,7 @@ func TestTMQModify(t *testing.T) { } TaosFreeResult(result) - pool := func(cb func(*common.Meta, unsafe.Pointer)) { + pool := func(cb func(*tmqcommon.Meta, unsafe.Pointer)) { message := TMQConsumerPoll(tmq, 500) assert.NotNil(t, message) topic := TMQGetTopicName(message) @@ -1115,7 +1102,7 @@ func TestTMQModify(t *testing.T) { pointer := TMQGetJsonMeta(message) assert.NotNil(t, pointer) data := ParseJsonMeta(pointer) - var meta common.Meta + var meta tmqcommon.Meta err = jsoniter.Unmarshal(data, &meta) assert.NoError(t, err) @@ -1144,10 +1131,9 @@ func TestTMQModify(t *testing.T) { } cb(&meta, rawMeta) TMQFreeRaw(rawMeta) - return } - pool(func(meta *common.Meta, rawMeta unsafe.Pointer) { + pool(func(meta *tmqcommon.Meta, rawMeta unsafe.Pointer) { assert.Equal(t, "create", meta.Type) assert.Equal(t, "stb", meta.TableName) assert.Equal(t, "super", meta.TableType) @@ -1202,3 +1188,194 @@ func TestTMQModify(t *testing.T) { return } } + +func TestTMQAutoCreateTable(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + defer TaosClose(conn) + defer func() { + result := TaosQuery(conn, "drop database if exists tmq_test_auto_create") + code := TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + }() + result := TaosQuery(conn, "create database if not exists tmq_test_auto_create vgroups 2 WAL_RETENTION_PERIOD 86400") + code := TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + + result = TaosQuery(conn, "use tmq_test_auto_create") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + + result = TaosQuery(conn, "create stable if not exists st1 (ts timestamp, c1 int, c2 float, c3 binary(10)) tags(t1 int)") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + + //create topic + result = TaosQuery(conn, "create topic if not exists test_tmq_auto_topic with meta as DATABASE tmq_test_auto_create") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + defer func() { + result = TaosQuery(conn, "drop topic if exists test_tmq_auto_topic") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + }() + result = TaosQuery(conn, "insert into ct1 using st1 tags(2000) values(now,1,2,'1')") + code = TaosError(result) + if code != 0 { + errStr := TaosErrorStr(result) + TaosFreeResult(result) + t.Error(errors.TaosError{Code: int32(code), ErrStr: errStr}) + return + } + TaosFreeResult(result) + //build consumer + conf := TMQConfNew() + // auto commit default is true then the commitCallback function will be called after 5 seconds + TMQConfSet(conf, "enable.auto.commit", "true") + TMQConfSet(conf, "group.id", "tg2") + TMQConfSet(conf, "msg.with.table.name", "true") + c := make(chan *TMQCommitCallbackResult, 1) + h := cgo.NewHandle(c) + TMQConfSetAutoCommitCB(conf, h) + go func() { + for r := range c { + t.Log("auto commit", r) + PutTMQCommitCallbackResult(r) + } + }() + tmq, err := TMQConsumerNew(conf) + if err != nil { + t.Error(err) + } + TMQConfDestroy(conf) + //build_topic_list + topicList := TMQListNew() + TMQListAppend(topicList, "test_tmq_auto_topic") + + //sync_consume_loop + errCode := TMQSubscribe(tmq, topicList) + if errCode != 0 { + errStr := TMQErr2Str(errCode) + t.Error(errors.NewError(int(errCode), errStr)) + return + } + errCode, list := TMQSubscription(tmq) + if errCode != 0 { + errStr := TMQErr2Str(errCode) + t.Error(errors.NewError(int(errCode), errStr)) + return + } + size := TMQListGetSize(list) + r := TMQListToCArray(list, int(size)) + assert.Equal(t, []string{"test_tmq_auto_topic"}, r) + totalCount := 0 + c2 := make(chan *TMQCommitCallbackResult, 1) + h2 := cgo.NewHandle(c2) + for i := 0; i < 5; i++ { + message := TMQConsumerPoll(tmq, 500) + if message != nil { + t.Log(message) + topic := TMQGetTopicName(message) + assert.Equal(t, "test_tmq_auto_topic", topic) + messageType := TMQGetResType(message) + if messageType != common.TMQ_RES_METADATA { + continue + } + pointer := TMQGetJsonMeta(message) + data := ParseJsonMeta(pointer) + t.Log(string(data)) + var meta tmqcommon.Meta + err = jsoniter.Unmarshal(data, &meta) + assert.NoError(t, err) + assert.Equal(t, "create", meta.Type) + for { + blockSize, errCode, block := TaosFetchRawBlock(message) + if errCode != int(errors.SUCCESS) { + errStr := TaosErrorStr(message) + err := errors.NewError(errCode, errStr) + t.Error(err) + TaosFreeResult(message) + return + } + if blockSize == 0 { + break + } + tableName := TMQGetTableName(message) + assert.Equal(t, "ct1", tableName) + filedCount := TaosNumFields(message) + rh, err := ReadColumn(message, filedCount) + if err != nil { + t.Error(err) + return + } + precision := TaosResultPrecision(message) + totalCount += blockSize + data := parser.ReadBlock(block, blockSize, rh.ColTypes, precision) + t.Log(data) + } + TaosFreeResult(message) + + TMQCommitAsync(tmq, nil, h2) + timer := time.NewTimer(time.Minute) + select { + case d := <-c2: + assert.Nil(t, d.GetError()) + assert.Equal(t, int32(0), d.ErrCode) + PutTMQCommitCallbackResult(d) + timer.Stop() + break + case <-timer.C: + timer.Stop() + t.Error("wait tmq commit callback timeout") + return + } + } + } + + errCode = TMQConsumerClose(tmq) + if errCode != 0 { + errStr := TMQErr2Str(errCode) + t.Error(errors.NewError(int(errCode), errStr)) + return + } + assert.GreaterOrEqual(t, totalCount, 1) +} diff --git a/wrapper/tmqcb.go b/wrapper/tmqcb.go index 52cb3d2..82ae1ee 100644 --- a/wrapper/tmqcb.go +++ b/wrapper/tmqcb.go @@ -17,7 +17,7 @@ import ( //export TMQCommitCB func TMQCommitCB(consumer unsafe.Pointer, resp C.int32_t, param unsafe.Pointer) { - c := cgo.Handle(param).Value().(chan *TMQCommitCallbackResult) + c := (*(*cgo.Handle)(param)).Value().(chan *TMQCommitCallbackResult) r := GetTMQCommitCallbackResult(int32(resp), consumer) defer func() { // Avoid panic due to channel closed diff --git a/ws/client/conn.go b/ws/client/conn.go new file mode 100644 index 0000000..493a7f1 --- /dev/null +++ b/ws/client/conn.go @@ -0,0 +1,181 @@ +package client + +import ( + "bytes" + "encoding/json" + "sync" + "sync/atomic" + "time" + + "github.com/gorilla/websocket" + jsoniter "github.com/json-iterator/go" + "github.com/taosdata/driver-go/v3/common" +) + +const ( + StatusNormal = uint32(1) + StatusStop = uint32(2) +) + +var JsonI = jsoniter.ConfigCompatibleWithStandardLibrary + +type WSAction struct { + Action string `json:"action"` + Args json.RawMessage `json:"args"` +} + +var GlobalEnvelopePool EnvelopePool + +type EnvelopePool struct { + p sync.Pool +} + +func (ep *EnvelopePool) Get() *Envelope { + epv := ep.p.Get() + if epv == nil { + return &Envelope{Msg: new(bytes.Buffer)} + } + return epv.(*Envelope) +} + +func (ep *EnvelopePool) Put(epv *Envelope) { + epv.Reset() + ep.p.Put(epv) +} + +type Envelope struct { + Type int + Msg *bytes.Buffer +} + +func (e *Envelope) Reset() { + e.Msg.Reset() +} + +type Client struct { + conn *websocket.Conn + status uint32 + sendChan chan *Envelope + BufferSize int + WriteWait time.Duration + PingPeriod time.Duration + PongWait time.Duration + TextMessageHandler func(message []byte) + BinaryMessageHandler func(message []byte) + ErrorHandler func(err error) + SendMessageHandler func(envelope *Envelope) + once sync.Once + errHandlerOnce sync.Once +} + +func NewClient(conn *websocket.Conn, sendChanLength uint) *Client { + return &Client{ + conn: conn, + status: StatusNormal, + BufferSize: common.BufferSize4M, + sendChan: make(chan *Envelope, sendChanLength), + WriteWait: common.DefaultWriteWait, + PingPeriod: common.DefaultPingPeriod, + PongWait: common.DefaultPongWait, + TextMessageHandler: func(message []byte) {}, + BinaryMessageHandler: func(message []byte) {}, + ErrorHandler: func(err error) {}, + SendMessageHandler: func(envelope *Envelope) { + GlobalEnvelopePool.Put(envelope) + }, + } +} + +func (c *Client) ReadPump() { + c.conn.SetReadLimit(common.BufferSize4M) + c.conn.SetReadDeadline(time.Now().Add(c.PongWait)) + c.conn.SetPongHandler(func(string) error { + c.conn.SetReadDeadline(time.Now().Add(c.PongWait)) + return nil + }) + c.conn.SetCloseHandler(nil) + for { + messageType, message, err := c.conn.ReadMessage() + if err != nil { + if e, ok := err.(*websocket.CloseError); ok && e.Code == websocket.CloseAbnormalClosure { + break + } + c.handleError(err) + break + } + switch messageType { + case websocket.TextMessage: + c.TextMessageHandler(message) + case websocket.BinaryMessage: + c.BinaryMessageHandler(message) + } + } +} + +func (c *Client) WritePump() { + ticker := time.NewTicker(c.PingPeriod) + defer func() { + ticker.Stop() + }() + for { + select { + case message, ok := <-c.sendChan: + if !ok { + return + } + c.conn.SetWriteDeadline(time.Now().Add(c.WriteWait)) + err := c.conn.WriteMessage(message.Type, message.Msg.Bytes()) + if err != nil { + c.handleError(err) + return + } + c.SendMessageHandler(message) + case <-ticker.C: + c.conn.SetWriteDeadline(time.Now().Add(c.WriteWait)) + if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil { + c.handleError(err) + return + } + } + } +} + +func (c *Client) Send(envelope *Envelope) { + if !c.IsRunning() { + return + } + defer func() { + // maybe closed + if recover() != nil { + + return + } + }() + c.sendChan <- envelope +} + +func (c *Client) GetEnvelope() *Envelope { + return GlobalEnvelopePool.Get() +} + +func (c *Client) PutEnvelope(envelope *Envelope) { + GlobalEnvelopePool.Put(envelope) +} + +func (c *Client) IsRunning() bool { + return atomic.LoadUint32(&c.status) == StatusNormal +} + +func (c *Client) Close() { + c.once.Do(func() { + close(c.sendChan) + atomic.StoreUint32(&c.status, StatusStop) + if c.conn != nil { + c.conn.Close() + } + }) +} + +func (c *Client) handleError(err error) { + c.errHandlerOnce.Do(func() { c.ErrorHandler(err) }) +} diff --git a/ws/stmt/config.go b/ws/stmt/config.go new file mode 100644 index 0000000..332ac55 --- /dev/null +++ b/ws/stmt/config.go @@ -0,0 +1,62 @@ +package stmt + +import ( + "errors" + "time" +) + +type Config struct { + Url string + ChanLength uint + MessageTimeout time.Duration + WriteWait time.Duration + ErrorHandler func(connector *Connector, err error) + CloseHandler func() + User string + Password string + DB string +} + +func NewConfig(url string, chanLength uint) *Config { + return &Config{ + Url: url, + ChanLength: chanLength, + } +} +func (c *Config) SetConnectUser(user string) error { + c.User = user + return nil +} + +func (c *Config) SetConnectPass(pass string) error { + c.Password = pass + return nil +} +func (c *Config) SetConnectDB(db string) error { + c.DB = db + return nil +} + +func (c *Config) SetMessageTimeout(timeout time.Duration) error { + if timeout < time.Second { + return errors.New("message timeout cannot be less than 1 second") + } + c.MessageTimeout = timeout + return nil +} + +func (c *Config) SetWriteWait(writeWait time.Duration) error { + if writeWait < 0 { + return errors.New("write wait cannot be less than 0") + } + c.WriteWait = writeWait + return nil +} + +func (c *Config) SetErrorHandler(f func(connector *Connector, err error)) { + c.ErrorHandler = f +} + +func (c *Config) SetCloseHandler(f func()) { + c.CloseHandler = f +} diff --git a/ws/stmt/connector.go b/ws/stmt/connector.go new file mode 100644 index 0000000..01ba361 --- /dev/null +++ b/ws/stmt/connector.go @@ -0,0 +1,278 @@ +package stmt + +import ( + "container/list" + "context" + "errors" + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/gorilla/websocket" + jsoniter "github.com/json-iterator/go" + "github.com/taosdata/driver-go/v3/common" + taosErrors "github.com/taosdata/driver-go/v3/errors" + "github.com/taosdata/driver-go/v3/ws/client" +) + +type Connector struct { + client *client.Client + requestID uint64 + listLock sync.RWMutex + sendChanList *list.List + writeTimeout time.Duration + readTimeout time.Duration + config *Config + closeOnce sync.Once + closeChan chan struct{} + customErrorHandler func(*Connector, error) + customCloseHandler func() +} + +var ( + ConnectTimeoutErr = errors.New("stmt connect timeout") +) + +func NewConnector(config *Config) (*Connector, error) { + var connector *Connector + readTimeout := common.DefaultMessageTimeout + writeTimeout := common.DefaultWriteWait + if config.MessageTimeout > 0 { + readTimeout = config.MessageTimeout + } + if config.WriteWait > 0 { + writeTimeout = config.WriteWait + } + ws, _, err := common.DefaultDialer.Dial(config.Url, nil) + if err != nil { + return nil, err + } + defer func() { + if connector == nil { + ws.Close() + } + }() + if config.MessageTimeout <= 0 { + config.MessageTimeout = common.DefaultMessageTimeout + } + req := &ConnectReq{ + ReqID: 0, + User: config.User, + Password: config.Password, + DB: config.DB, + } + args, err := client.JsonI.Marshal(req) + if err != nil { + return nil, err + } + action := &client.WSAction{ + Action: STMTConnect, + Args: args, + } + connectAction, err := client.JsonI.Marshal(action) + if err != nil { + return nil, err + } + ws.SetWriteDeadline(time.Now().Add(writeTimeout)) + err = ws.WriteMessage(websocket.TextMessage, connectAction) + if err != nil { + return nil, err + } + done := make(chan struct{}) + ctx, cancel := context.WithTimeout(context.Background(), readTimeout) + var respBytes []byte + go func() { + _, respBytes, err = ws.ReadMessage() + close(done) + }() + select { + case <-done: + cancel() + case <-ctx.Done(): + cancel() + return nil, ConnectTimeoutErr + } + if err != nil { + return nil, err + } + var resp ConnectResp + err = client.JsonI.Unmarshal(respBytes, &resp) + if err != nil { + return nil, err + } + if resp.Code != 0 { + return nil, taosErrors.NewError(resp.Code, resp.Message) + } + wsClient := client.NewClient(ws, config.ChanLength) + wsClient.WriteWait = writeTimeout + connector = &Connector{ + client: wsClient, + requestID: 0, + listLock: sync.RWMutex{}, + sendChanList: list.New(), + writeTimeout: writeTimeout, + readTimeout: readTimeout, + config: config, + closeOnce: sync.Once{}, + closeChan: make(chan struct{}), + customErrorHandler: config.ErrorHandler, + customCloseHandler: config.CloseHandler, + } + + wsClient.TextMessageHandler = connector.handleTextMessage + wsClient.ErrorHandler = connector.handleError + go wsClient.WritePump() + go wsClient.ReadPump() + return connector, nil +} + +func (c *Connector) handleTextMessage(message []byte) { + iter := client.JsonI.BorrowIterator(message) + var reqID uint64 + iter.ReadObjectCB(func(iter *jsoniter.Iterator, s string) bool { + switch s { + case "req_id": + reqID = iter.ReadUint64() + return false + default: + iter.Skip() + } + return iter.Error == nil + }) + client.JsonI.ReturnIterator(iter) + c.listLock.Lock() + element := c.findOutChanByID(reqID) + if element != nil { + element.Value.(*IndexedChan).channel <- message + c.sendChanList.Remove(element) + } + c.listLock.Unlock() +} + +type IndexedChan struct { + index uint64 + channel chan []byte +} + +func (c *Connector) sendText(reqID uint64, envelope *client.Envelope) ([]byte, error) { + envelope.Type = websocket.TextMessage + return c.send(reqID, envelope) +} +func (c *Connector) sendBinary(reqID uint64, envelope *client.Envelope) ([]byte, error) { + envelope.Type = websocket.BinaryMessage + return c.send(reqID, envelope) +} +func (c *Connector) send(reqID uint64, envelope *client.Envelope) ([]byte, error) { + channel := &IndexedChan{ + index: reqID, + channel: make(chan []byte, 1), + } + element := c.addMessageOutChan(channel) + c.client.Send(envelope) + ctx, cancel := context.WithTimeout(context.Background(), c.readTimeout) + defer cancel() + select { + case <-c.closeChan: + return nil, errors.New("connection closed") + case resp := <-channel.channel: + return resp, nil + case <-ctx.Done(): + c.listLock.Lock() + c.sendChanList.Remove(element) + c.listLock.Unlock() + return nil, fmt.Errorf("message timeout :%s", envelope.Msg.String()) + } +} + +func (c *Connector) sendTextWithoutResp(envelope *client.Envelope) { + envelope.Type = websocket.TextMessage + c.client.Send(envelope) +} + +func (c *Connector) findOutChanByID(index uint64) *list.Element { + root := c.sendChanList.Front() + if root == nil { + return nil + } + rootIndex := root.Value.(*IndexedChan).index + if rootIndex == index { + return root + } + item := root.Next() + for { + if item == nil || item == root { + return nil + } + if item.Value.(*IndexedChan).index == index { + return item + } + item = item.Next() + } +} + +func (c *Connector) addMessageOutChan(outChan *IndexedChan) *list.Element { + c.listLock.Lock() + element := c.sendChanList.PushBack(outChan) + c.listLock.Unlock() + return element +} + +func (c *Connector) handleError(err error) { + if c.customErrorHandler != nil { + c.customErrorHandler(c, err) + } + c.Close() +} + +func (c *Connector) generateReqID() uint64 { + return atomic.AddUint64(&c.requestID, 1) +} + +func (c *Connector) Init() (*Stmt, error) { + reqID := c.generateReqID() + req := &InitReq{ + ReqID: reqID, + } + args, err := client.JsonI.Marshal(req) + if err != nil { + return nil, err + } + action := &client.WSAction{ + Action: STMTInit, + Args: args, + } + envelope := c.client.GetEnvelope() + err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) + if err != nil { + c.client.PutEnvelope(envelope) + return nil, err + } + respBytes, err := c.sendText(reqID, envelope) + if err != nil { + return nil, err + } + var resp InitResp + err = client.JsonI.Unmarshal(respBytes, &resp) + if err != nil { + return nil, err + } + if resp.Code != 0 { + return nil, taosErrors.NewError(resp.Code, resp.Message) + } + return &Stmt{ + id: resp.StmtID, + connector: c, + }, nil +} + +func (c *Connector) Close() error { + c.closeOnce.Do(func() { + close(c.closeChan) + c.client.Close() + if c.customCloseHandler != nil { + c.customCloseHandler() + } + }) + return nil +} diff --git a/ws/stmt/proto.go b/ws/stmt/proto.go new file mode 100644 index 0000000..2fed0ab --- /dev/null +++ b/ws/stmt/proto.go @@ -0,0 +1,136 @@ +package stmt + +import "encoding/json" + +const ( + SetTagsMessage = 1 + BindMessage = 2 +) + +const ( + STMTConnect = "conn" + STMTInit = "init" + STMTPrepare = "prepare" + STMTSetTableName = "set_table_name" + STMTAddBatch = "add_batch" + STMTExec = "exec" + STMTClose = "close" +) + +type ConnectReq struct { + ReqID uint64 `json:"req_id"` + User string `json:"user"` + Password string `json:"password"` + DB string `json:"db"` +} + +type ConnectResp struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` +} + +type InitReq struct { + ReqID uint64 `json:"req_id"` +} + +type InitResp struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` +} + +type PrepareReq struct { + ReqID uint64 `json:"req_id"` + StmtID uint64 `json:"stmt_id"` + SQL string `json:"sql"` +} +type PrepareResp struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` +} + +type SetTableNameReq struct { + ReqID uint64 `json:"req_id"` + StmtID uint64 `json:"stmt_id"` + Name string `json:"name"` +} + +type SetTableNameResp struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` +} + +type SetTagsReq struct { + ReqID uint64 `json:"req_id"` + StmtID uint64 `json:"stmt_id"` + Tags json.RawMessage `json:"tags"` +} + +type SetTagsResp struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` +} + +type BindReq struct { + ReqID uint64 `json:"req_id"` + StmtID uint64 `json:"stmt_id"` + Columns json.RawMessage `json:"columns"` +} +type BindResp struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` +} + +type AddBatchReq struct { + ReqID uint64 `json:"req_id"` + StmtID uint64 `json:"stmt_id"` +} +type AddBatchResp struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` +} + +type ExecReq struct { + ReqID uint64 `json:"req_id"` + StmtID uint64 `json:"stmt_id"` +} +type ExecResp struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` + Affected int `json:"affected"` +} + +type CloseReq struct { + ReqID uint64 `json:"req_id"` + StmtID uint64 `json:"stmt_id"` +} diff --git a/ws/stmt/stmt.go b/ws/stmt/stmt.go new file mode 100644 index 0000000..e3c4c74 --- /dev/null +++ b/ws/stmt/stmt.go @@ -0,0 +1,255 @@ +package stmt + +import ( + "encoding/binary" + + "github.com/taosdata/driver-go/v3/common/param" + "github.com/taosdata/driver-go/v3/common/serializer" + taosErrors "github.com/taosdata/driver-go/v3/errors" + "github.com/taosdata/driver-go/v3/ws/client" +) + +type Stmt struct { + connector *Connector + id uint64 + lastAffected int +} + +func (s *Stmt) Prepare(sql string) error { + reqID := s.connector.generateReqID() + req := &PrepareReq{ + ReqID: reqID, + StmtID: s.id, + SQL: sql, + } + args, err := client.JsonI.Marshal(req) + if err != nil { + return err + } + action := &client.WSAction{ + Action: STMTPrepare, + Args: args, + } + envelope := s.connector.client.GetEnvelope() + err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) + if err != nil { + s.connector.client.PutEnvelope(envelope) + return err + } + respBytes, err := s.connector.sendText(reqID, envelope) + if err != nil { + return err + } + var resp PrepareResp + err = client.JsonI.Unmarshal(respBytes, &resp) + if err != nil { + return err + } + if resp.Code != 0 { + return taosErrors.NewError(resp.Code, resp.Message) + } + return nil +} + +func (s *Stmt) SetTableName(name string) error { + reqID := s.connector.generateReqID() + req := &SetTableNameReq{ + ReqID: reqID, + StmtID: s.id, + Name: name, + } + args, err := client.JsonI.Marshal(req) + if err != nil { + return err + } + action := &client.WSAction{ + Action: STMTSetTableName, + Args: args, + } + envelope := s.connector.client.GetEnvelope() + err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) + if err != nil { + s.connector.client.PutEnvelope(envelope) + return err + } + respBytes, err := s.connector.sendText(reqID, envelope) + if err != nil { + return err + } + var resp SetTableNameResp + err = client.JsonI.Unmarshal(respBytes, &resp) + if err != nil { + return err + } + if resp.Code != 0 { + return taosErrors.NewError(resp.Code, resp.Message) + } + return nil +} + +func (s *Stmt) SetTags(tags *param.Param, bindType *param.ColumnType) error { + tagValues := tags.GetValues() + reverseTags := make([]*param.Param, len(tagValues)) + for i := 0; i < len(tagValues); i++ { + reverseTags[i] = param.NewParam(1).AddValue(tagValues[i]) + } + block, err := serializer.SerializeRawBlock(reverseTags, bindType) + if err != nil { + return err + } + reqID := s.connector.generateReqID() + reqData := make([]byte, 24) + binary.LittleEndian.PutUint64(reqData, reqID) + binary.LittleEndian.PutUint64(reqData[8:], s.id) + binary.LittleEndian.PutUint64(reqData[16:], SetTagsMessage) + envelope := s.connector.client.GetEnvelope() + envelope.Msg.Grow(24 + len(block)) + envelope.Msg.Write(reqData) + envelope.Msg.Write(block) + respBytes, err := s.connector.sendBinary(reqID, envelope) + if err != nil { + return err + } + var resp SetTagsResp + err = client.JsonI.Unmarshal(respBytes, &resp) + if err != nil { + return err + } + if resp.Code != 0 { + return taosErrors.NewError(resp.Code, resp.Message) + } + return nil +} + +func (s *Stmt) BindParam(params []*param.Param, bindType *param.ColumnType) error { + block, err := serializer.SerializeRawBlock(params, bindType) + if err != nil { + return err + } + reqID := s.connector.generateReqID() + reqData := make([]byte, 24) + binary.LittleEndian.PutUint64(reqData, reqID) + binary.LittleEndian.PutUint64(reqData[8:], s.id) + binary.LittleEndian.PutUint64(reqData[16:], BindMessage) + envelope := s.connector.client.GetEnvelope() + envelope.Msg.Grow(24 + len(block)) + envelope.Msg.Write(reqData) + envelope.Msg.Write(block) + err = client.JsonI.NewEncoder(envelope.Msg).Encode(reqData) + if err != nil { + s.connector.client.PutEnvelope(envelope) + return err + } + respBytes, err := s.connector.sendBinary(reqID, envelope) + if err != nil { + return err + } + var resp BindResp + err = client.JsonI.Unmarshal(respBytes, &resp) + if err != nil { + return err + } + if resp.Code != 0 { + return taosErrors.NewError(resp.Code, resp.Message) + } + return nil +} + +func (s *Stmt) AddBatch() error { + reqID := s.connector.generateReqID() + req := &AddBatchReq{ + ReqID: reqID, + StmtID: s.id, + } + args, err := client.JsonI.Marshal(req) + if err != nil { + return err + } + action := &client.WSAction{ + Action: STMTAddBatch, + Args: args, + } + envelope := s.connector.client.GetEnvelope() + err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) + if err != nil { + s.connector.client.PutEnvelope(envelope) + return err + } + respBytes, err := s.connector.sendText(reqID, envelope) + if err != nil { + return err + } + var resp AddBatchResp + err = client.JsonI.Unmarshal(respBytes, &resp) + if err != nil { + return err + } + if resp.Code != 0 { + return taosErrors.NewError(resp.Code, resp.Message) + } + return nil +} + +func (s *Stmt) Exec() error { + reqID := s.connector.generateReqID() + req := &ExecReq{ + ReqID: reqID, + StmtID: s.id, + } + args, err := client.JsonI.Marshal(req) + if err != nil { + return err + } + action := &client.WSAction{ + Action: STMTExec, + Args: args, + } + envelope := s.connector.client.GetEnvelope() + err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) + if err != nil { + s.connector.client.PutEnvelope(envelope) + return err + } + respBytes, err := s.connector.sendText(reqID, envelope) + if err != nil { + return err + } + var resp ExecResp + err = client.JsonI.Unmarshal(respBytes, &resp) + if err != nil { + return err + } + if resp.Code != 0 { + return taosErrors.NewError(resp.Code, resp.Message) + } + s.lastAffected = resp.Affected + return nil +} + +func (s *Stmt) GetAffectedRows() int { + return s.lastAffected +} + +func (s *Stmt) Close() error { + reqID := s.connector.generateReqID() + req := &CloseReq{ + ReqID: reqID, + StmtID: s.id, + } + args, err := client.JsonI.Marshal(req) + if err != nil { + return err + } + action := &client.WSAction{ + Action: STMTClose, + Args: args, + } + envelope := s.connector.client.GetEnvelope() + err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) + if err != nil { + s.connector.client.PutEnvelope(envelope) + return err + } + s.connector.sendTextWithoutResp(envelope) + return nil +} diff --git a/ws/stmt/stmt_test.go b/ws/stmt/stmt_test.go new file mode 100644 index 0000000..f12d58e --- /dev/null +++ b/ws/stmt/stmt_test.go @@ -0,0 +1,613 @@ +package stmt + +import ( + "database/sql/driver" + "fmt" + "io" + "io/ioutil" + "net/http" + "strings" + "testing" + "time" + + jsoniter "github.com/json-iterator/go" + "github.com/stretchr/testify/assert" + "github.com/taosdata/driver-go/v3/common" + "github.com/taosdata/driver-go/v3/common/param" + taosErrors "github.com/taosdata/driver-go/v3/errors" + "github.com/taosdata/driver-go/v3/ws/client" +) + +func prepareEnv() error { + var err error + steps := []string{ + "drop database if exists test_ws_stmt", + "create database test_ws_stmt", + "create table test_ws_stmt.all_json(ts timestamp," + + "c1 bool," + + "c2 tinyint," + + "c3 smallint," + + "c4 int," + + "c5 bigint," + + "c6 tinyint unsigned," + + "c7 smallint unsigned," + + "c8 int unsigned," + + "c9 bigint unsigned," + + "c10 float," + + "c11 double," + + "c12 binary(20)," + + "c13 nchar(20)" + + ")" + + "tags(t json)", + "create table test_ws_stmt.all_all(" + + "ts timestamp," + + "c1 bool," + + "c2 tinyint," + + "c3 smallint," + + "c4 int," + + "c5 bigint," + + "c6 tinyint unsigned," + + "c7 smallint unsigned," + + "c8 int unsigned," + + "c9 bigint unsigned," + + "c10 float," + + "c11 double," + + "c12 binary(20)," + + "c13 nchar(20)" + + ")" + + "tags(" + + "tts timestamp," + + "tc1 bool," + + "tc2 tinyint," + + "tc3 smallint," + + "tc4 int," + + "tc5 bigint," + + "tc6 tinyint unsigned," + + "tc7 smallint unsigned," + + "tc8 int unsigned," + + "tc9 bigint unsigned," + + "tc10 float," + + "tc11 double," + + "tc12 binary(20)," + + "tc13 nchar(20))", + } + for _, step := range steps { + err = doRequest(step) + if err != nil { + return err + } + } + return nil +} + +func cleanEnv() error { + var err error + time.Sleep(2 * time.Second) + steps := []string{ + "drop database if exists test_ws_stmt", + } + for _, step := range steps { + err = doRequest(step) + if err != nil { + return err + } + } + return nil +} + +func doRequest(payload string) error { + body := strings.NewReader(payload) + req, _ := http.NewRequest(http.MethodPost, "http://127.0.0.1:6041/rest/sql", body) + req.Header.Set("Authorization", "Taosd /KfeAzX/f9na8qdtNZmtONryp201ma04bEl8LcvLUd7a8qdtNZmtONryp201ma04") + resp, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("http code: %d", resp.StatusCode) + } + data, err := ioutil.ReadAll(resp.Body) + if err != nil { + return err + } + iter := client.JsonI.BorrowIterator(data) + code := int32(0) + desc := "" + iter.ReadObjectCB(func(iter *jsoniter.Iterator, s string) bool { + switch s { + case "code": + code = iter.ReadInt32() + case "desc": + desc = iter.ReadString() + default: + iter.Skip() + } + return iter.Error == nil + }) + client.JsonI.ReturnIterator(iter) + if code != 0 { + return taosErrors.NewError(int(code), desc) + } + return nil +} + +func query(payload string) (*common.TDEngineRestfulResp, error) { + body := strings.NewReader(payload) + req, _ := http.NewRequest(http.MethodPost, "http://127.0.0.1:6041/rest/sql", body) + req.Header.Set("Authorization", "Taosd /KfeAzX/f9na8qdtNZmtONryp201ma04bEl8LcvLUd7a8qdtNZmtONryp201ma04") + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("http code: %d", resp.StatusCode) + } + return marshalBody(resp.Body, 512) +} + +func TestStmt(t *testing.T) { + err := prepareEnv() + if err != nil { + t.Error(err) + return + } + defer cleanEnv() + now := time.Now() + config := NewConfig("ws://127.0.0.1:6041/rest/stmt", 0) + config.SetConnectUser("root") + config.SetConnectPass("taosdata") + config.SetConnectDB("test_ws_stmt") + config.SetMessageTimeout(common.DefaultMessageTimeout) + config.SetWriteWait(common.DefaultWriteWait) + config.SetErrorHandler(func(connector *Connector, err error) { + t.Log(err) + }) + config.SetCloseHandler(func() { + t.Log("stmt websocket closed") + }) + connector, err := NewConnector(config) + if err != nil { + t.Error(err) + return + } + defer connector.Close() + { + stmt, err := connector.Init() + if err != nil { + t.Error(err) + return + } + err = stmt.Prepare("insert into ? using all_json tags(?) values(?,?,?,?,?,?,?,?,?,?,?,?,?,?)") + if err != nil { + t.Error(err) + return + } + err = stmt.SetTableName("tb1") + if err != nil { + t.Error(err) + return + } + err = stmt.SetTags(param.NewParam(1).AddJson([]byte(`{"tb":1}`)), param.NewColumnType(1).AddJson(0)) + if err != nil { + t.Error(err) + return + } + params := []*param.Param{ + param.NewParam(3).AddTimestamp(now, 0).AddTimestamp(now.Add(time.Second), 0).AddTimestamp(now.Add(time.Second*2), 0), + param.NewParam(3).AddBool(true).AddNull().AddBool(true), + param.NewParam(3).AddTinyint(1).AddNull().AddTinyint(1), + param.NewParam(3).AddSmallint(1).AddNull().AddSmallint(1), + param.NewParam(3).AddInt(1).AddNull().AddInt(1), + param.NewParam(3).AddBigint(1).AddNull().AddBigint(1), + param.NewParam(3).AddUTinyint(1).AddNull().AddUTinyint(1), + param.NewParam(3).AddUSmallint(1).AddNull().AddUSmallint(1), + param.NewParam(3).AddUInt(1).AddNull().AddUInt(1), + param.NewParam(3).AddUBigint(1).AddNull().AddUBigint(1), + param.NewParam(3).AddFloat(1).AddNull().AddFloat(1), + param.NewParam(3).AddDouble(1).AddNull().AddDouble(1), + param.NewParam(3).AddBinary([]byte("test_binary")).AddNull().AddBinary([]byte("test_binary")), + param.NewParam(3).AddNchar("test_nchar").AddNull().AddNchar("test_nchar"), + } + paramTypes := param.NewColumnType(14). + AddTimestamp(). + AddBool(). + AddTinyint(). + AddSmallint(). + AddInt(). + AddBigint(). + AddUTinyint(). + AddUSmallint(). + AddUInt(). + AddUBigint(). + AddFloat(). + AddDouble(). + AddBinary(0). + AddNchar(0) + err = stmt.BindParam(params, paramTypes) + if err != nil { + t.Error(err) + return + } + err = stmt.AddBatch() + if err != nil { + t.Error(err) + return + } + err = stmt.Exec() + if err != nil { + t.Error(err) + return + } + affected := stmt.GetAffectedRows() + if !assert.Equal(t, 3, affected) { + return + } + err = stmt.Close() + if err != nil { + t.Error(err) + return + } + result, err := query("select * from test_ws_stmt.all_json order by ts") + if err != nil { + t.Error(err) + return + } + assert.Equal(t, 0, result.Code, result) + assert.Equal(t, 3, len(result.Data)) + assert.Equal(t, 15, len(result.ColTypes)) + row1 := result.Data[0] + assert.Equal(t, now.UnixNano()/1e6, row1[0].(time.Time).UnixNano()/1e6) + assert.Equal(t, true, row1[1]) + assert.Equal(t, int8(1), row1[2]) + assert.Equal(t, int16(1), row1[3]) + assert.Equal(t, int32(1), row1[4]) + assert.Equal(t, int64(1), row1[5]) + assert.Equal(t, uint8(1), row1[6]) + assert.Equal(t, uint16(1), row1[7]) + assert.Equal(t, uint32(1), row1[8]) + assert.Equal(t, uint64(1), row1[9]) + assert.Equal(t, float32(1), row1[10]) + assert.Equal(t, float64(1), row1[11]) + assert.Equal(t, "test_binary", row1[12]) + assert.Equal(t, "test_nchar", row1[13]) + assert.Equal(t, []byte(`{"tb":1}`), row1[14]) + row2 := result.Data[1] + assert.Equal(t, now.Add(time.Second).UnixNano()/1e6, row2[0].(time.Time).UnixNano()/1e6) + for i := 1; i < 14; i++ { + assert.Nil(t, row2[i]) + } + assert.Equal(t, []byte(`{"tb":1}`), row2[14]) + row3 := result.Data[2] + assert.Equal(t, now.Add(time.Second*2).UnixNano()/1e6, row3[0].(time.Time).UnixNano()/1e6) + assert.Equal(t, true, row3[1]) + assert.Equal(t, int8(1), row3[2]) + assert.Equal(t, int16(1), row3[3]) + assert.Equal(t, int32(1), row3[4]) + assert.Equal(t, int64(1), row3[5]) + assert.Equal(t, uint8(1), row3[6]) + assert.Equal(t, uint16(1), row3[7]) + assert.Equal(t, uint32(1), row3[8]) + assert.Equal(t, uint64(1), row3[9]) + assert.Equal(t, float32(1), row3[10]) + assert.Equal(t, float64(1), row3[11]) + assert.Equal(t, "test_binary", row3[12]) + assert.Equal(t, "test_nchar", row3[13]) + assert.Equal(t, []byte(`{"tb":1}`), row3[14]) + } + { + stmt, err := connector.Init() + if err != nil { + t.Error(err) + return + } + err = stmt.Prepare("insert into ? using all_all tags(?,?,?,?,?,?,?,?,?,?,?,?,?,?) values(?,?,?,?,?,?,?,?,?,?,?,?,?,?)") + err = stmt.SetTableName("tb1") + if err != nil { + t.Error(err) + return + } + + err = stmt.SetTableName("tb2") + if err != nil { + t.Error(err) + return + } + err = stmt.SetTags( + param.NewParam(14). + AddTimestamp(now, 0). + AddBool(true). + AddTinyint(2). + AddSmallint(2). + AddInt(2). + AddBigint(2). + AddUTinyint(2). + AddUSmallint(2). + AddUInt(2). + AddUBigint(2). + AddFloat(2). + AddDouble(2). + AddBinary([]byte("tb2")). + AddNchar("tb2"), + param.NewColumnType(14). + AddTimestamp(). + AddBool(). + AddTinyint(). + AddSmallint(). + AddInt(). + AddBigint(). + AddUTinyint(). + AddUSmallint(). + AddUInt(). + AddUBigint(). + AddFloat(). + AddDouble(). + AddBinary(0). + AddNchar(0), + ) + if err != nil { + t.Error(err) + return + } + params := []*param.Param{ + param.NewParam(3).AddTimestamp(now, 0).AddTimestamp(now.Add(time.Second), 0).AddTimestamp(now.Add(time.Second*2), 0), + param.NewParam(3).AddBool(true).AddNull().AddBool(true), + param.NewParam(3).AddTinyint(1).AddNull().AddTinyint(1), + param.NewParam(3).AddSmallint(1).AddNull().AddSmallint(1), + param.NewParam(3).AddInt(1).AddNull().AddInt(1), + param.NewParam(3).AddBigint(1).AddNull().AddBigint(1), + param.NewParam(3).AddUTinyint(1).AddNull().AddUTinyint(1), + param.NewParam(3).AddUSmallint(1).AddNull().AddUSmallint(1), + param.NewParam(3).AddUInt(1).AddNull().AddUInt(1), + param.NewParam(3).AddUBigint(1).AddNull().AddUBigint(1), + param.NewParam(3).AddFloat(1).AddNull().AddFloat(1), + param.NewParam(3).AddDouble(1).AddNull().AddDouble(1), + param.NewParam(3).AddBinary([]byte("test_binary")).AddNull().AddBinary([]byte("test_binary")), + param.NewParam(3).AddNchar("test_nchar").AddNull().AddNchar("test_nchar"), + } + paramTypes := param.NewColumnType(14). + AddTimestamp(). + AddBool(). + AddTinyint(). + AddSmallint(). + AddInt(). + AddBigint(). + AddUTinyint(). + AddUSmallint(). + AddUInt(). + AddUBigint(). + AddFloat(). + AddDouble(). + AddBinary(0). + AddNchar(0) + err = stmt.BindParam(params, paramTypes) + if err != nil { + t.Error(err) + return + } + err = stmt.AddBatch() + if err != nil { + t.Error(err) + return + } + err = stmt.Exec() + if err != nil { + t.Error(err) + return + } + affected := stmt.GetAffectedRows() + if !assert.Equal(t, 3, affected) { + return + } + err = stmt.Close() + if err != nil { + t.Error(err) + return + } + result, err := query("select * from test_ws_stmt.all_all order by ts") + if err != nil { + t.Error(err) + return + } + assert.Equal(t, 3, affected) + assert.Equal(t, 0, result.Code, result) + assert.Equal(t, 3, len(result.Data)) + assert.Equal(t, 28, len(result.ColTypes)) + row1 := result.Data[0] + assert.Equal(t, now.UnixNano()/1e6, row1[0].(time.Time).UnixNano()/1e6) + assert.Equal(t, true, row1[1]) + assert.Equal(t, int8(1), row1[2]) + assert.Equal(t, int16(1), row1[3]) + assert.Equal(t, int32(1), row1[4]) + assert.Equal(t, int64(1), row1[5]) + assert.Equal(t, uint8(1), row1[6]) + assert.Equal(t, uint16(1), row1[7]) + assert.Equal(t, uint32(1), row1[8]) + assert.Equal(t, uint64(1), row1[9]) + assert.Equal(t, float32(1), row1[10]) + assert.Equal(t, float64(1), row1[11]) + assert.Equal(t, "test_binary", row1[12]) + assert.Equal(t, "test_nchar", row1[13]) + assert.Equal(t, now.UnixNano()/1e6, row1[14].(time.Time).UnixNano()/1e6) + assert.Equal(t, true, row1[15]) + assert.Equal(t, int8(2), row1[16]) + assert.Equal(t, int16(2), row1[17]) + assert.Equal(t, int32(2), row1[18]) + assert.Equal(t, int64(2), row1[19]) + assert.Equal(t, uint8(2), row1[20]) + assert.Equal(t, uint16(2), row1[21]) + assert.Equal(t, uint32(2), row1[22]) + assert.Equal(t, uint64(2), row1[23]) + assert.Equal(t, float32(2), row1[24]) + assert.Equal(t, float64(2), row1[25]) + assert.Equal(t, "tb2", row1[26]) + assert.Equal(t, "tb2", row1[27]) + row2 := result.Data[1] + assert.Equal(t, now.Add(time.Second).UnixNano()/1e6, row2[0].(time.Time).UnixNano()/1e6) + for i := 1; i < 14; i++ { + assert.Nil(t, row2[i]) + } + assert.Equal(t, now.UnixNano()/1e6, row1[14].(time.Time).UnixNano()/1e6) + assert.Equal(t, true, row1[15]) + assert.Equal(t, int8(2), row1[16]) + assert.Equal(t, int16(2), row1[17]) + assert.Equal(t, int32(2), row1[18]) + assert.Equal(t, int64(2), row1[19]) + assert.Equal(t, uint8(2), row1[20]) + assert.Equal(t, uint16(2), row1[21]) + assert.Equal(t, uint32(2), row1[22]) + assert.Equal(t, uint64(2), row1[23]) + assert.Equal(t, float32(2), row1[24]) + assert.Equal(t, float64(2), row1[25]) + assert.Equal(t, "tb2", row1[26]) + assert.Equal(t, "tb2", row1[27]) + row3 := result.Data[2] + assert.Equal(t, now.Add(time.Second*2).UnixNano()/1e6, row3[0].(time.Time).UnixNano()/1e6) + assert.Equal(t, true, row3[1]) + assert.Equal(t, int8(1), row3[2]) + assert.Equal(t, int16(1), row3[3]) + assert.Equal(t, int32(1), row3[4]) + assert.Equal(t, int64(1), row3[5]) + assert.Equal(t, uint8(1), row3[6]) + assert.Equal(t, uint16(1), row3[7]) + assert.Equal(t, uint32(1), row3[8]) + assert.Equal(t, uint64(1), row3[9]) + assert.Equal(t, float32(1), row3[10]) + assert.Equal(t, float64(1), row3[11]) + assert.Equal(t, "test_binary", row3[12]) + assert.Equal(t, "test_nchar", row3[13]) + assert.Equal(t, now.UnixNano()/1e6, row3[14].(time.Time).UnixNano()/1e6) + assert.Equal(t, true, row3[15]) + assert.Equal(t, int8(2), row3[16]) + assert.Equal(t, int16(2), row3[17]) + assert.Equal(t, int32(2), row3[18]) + assert.Equal(t, int64(2), row3[19]) + assert.Equal(t, uint8(2), row3[20]) + assert.Equal(t, uint16(2), row3[21]) + assert.Equal(t, uint32(2), row3[22]) + assert.Equal(t, uint64(2), row3[23]) + assert.Equal(t, float32(2), row3[24]) + assert.Equal(t, float64(2), row3[25]) + assert.Equal(t, "tb2", row3[26]) + assert.Equal(t, "tb2", row3[27]) + } +} + +func marshalBody(body io.Reader, bufferSize int) (*common.TDEngineRestfulResp, error) { + var result common.TDEngineRestfulResp + iter := client.JsonI.BorrowIterator(make([]byte, bufferSize)) + defer client.JsonI.ReturnIterator(iter) + iter.Reset(body) + timeFormat := time.RFC3339Nano + iter.ReadObjectCB(func(iter *jsoniter.Iterator, s string) bool { + switch s { + case "code": + result.Code = iter.ReadInt() + case "desc": + result.Desc = iter.ReadString() + case "column_meta": + iter.ReadArrayCB(func(iter *jsoniter.Iterator) bool { + index := 0 + iter.ReadArrayCB(func(iter *jsoniter.Iterator) bool { + switch index { + case 0: + result.ColNames = append(result.ColNames, iter.ReadString()) + index = 1 + case 1: + typeStr := iter.ReadString() + t, exist := common.NameTypeMap[typeStr] + if exist { + result.ColTypes = append(result.ColTypes, t) + } else { + iter.ReportError("unsupported type in column_meta", typeStr) + } + index = 2 + case 2: + result.ColLength = append(result.ColLength, iter.ReadInt64()) + index = 0 + } + return true + }) + return true + }) + case "data": + columnCount := len(result.ColTypes) + column := 0 + iter.ReadArrayCB(func(iter *jsoniter.Iterator) bool { + column = 0 + var row = make([]driver.Value, columnCount) + iter.ReadArrayCB(func(iter *jsoniter.Iterator) bool { + defer func() { + column += 1 + }() + columnType := result.ColTypes[column] + if columnType == common.TSDB_DATA_TYPE_JSON { + row[column] = iter.SkipAndReturnBytes() + return true + } + if iter.ReadNil() { + row[column] = nil + return true + } + var err error + switch columnType { + case common.TSDB_DATA_TYPE_NULL: + iter.Skip() + row[column] = nil + case common.TSDB_DATA_TYPE_BOOL: + row[column] = iter.ReadAny().ToBool() + case common.TSDB_DATA_TYPE_TINYINT: + row[column] = iter.ReadInt8() + case common.TSDB_DATA_TYPE_SMALLINT: + row[column] = iter.ReadInt16() + case common.TSDB_DATA_TYPE_INT: + row[column] = iter.ReadInt32() + case common.TSDB_DATA_TYPE_BIGINT: + row[column] = iter.ReadInt64() + case common.TSDB_DATA_TYPE_FLOAT: + row[column] = iter.ReadFloat32() + case common.TSDB_DATA_TYPE_DOUBLE: + row[column] = iter.ReadFloat64() + case common.TSDB_DATA_TYPE_BINARY: + row[column] = iter.ReadString() + case common.TSDB_DATA_TYPE_TIMESTAMP: + b := iter.ReadString() + row[column], err = time.Parse(timeFormat, b) + if err != nil { + iter.ReportError("parse time", err.Error()) + } + case common.TSDB_DATA_TYPE_NCHAR: + row[column] = iter.ReadString() + case common.TSDB_DATA_TYPE_UTINYINT: + row[column] = iter.ReadUint8() + case common.TSDB_DATA_TYPE_USMALLINT: + row[column] = iter.ReadUint16() + case common.TSDB_DATA_TYPE_UINT: + row[column] = iter.ReadUint32() + case common.TSDB_DATA_TYPE_UBIGINT: + row[column] = iter.ReadUint64() + default: + row[column] = nil + iter.Skip() + } + return iter.Error == nil + }) + if iter.Error != nil { + return false + } + result.Data = append(result.Data, row) + return true + }) + case "rows": + result.Rows = iter.ReadInt() + default: + iter.Skip() + } + return iter.Error == nil + }) + if iter.Error != nil && iter.Error != io.EOF { + return nil, iter.Error + } + return &result, nil +} diff --git a/ws/tmq/config.go b/ws/tmq/config.go new file mode 100644 index 0000000..1441f5a --- /dev/null +++ b/ws/tmq/config.go @@ -0,0 +1,140 @@ +package tmq + +import ( + "errors" + "fmt" + "time" + + "github.com/taosdata/driver-go/v3/common/tmq" +) + +type config struct { + Url string + ChanLength uint + MessageTimeout time.Duration + WriteWait time.Duration + User string + Password string + GroupID string + ClientID string + OffsetRest string + AutoCommit string + AutoCommitIntervalMS string + SnapshotEnable string + WithTableName string +} + +func newConfig(url string, chanLength uint) *config { + return &config{ + Url: url, + ChanLength: chanLength, + } +} + +func (c *config) setConnectUser(user tmq.ConfigValue) error { + var ok bool + c.User, ok = user.(string) + if !ok { + return fmt.Errorf("td.connect.user requires string got %T", user) + } + return nil +} + +func (c *config) setConnectPass(pass tmq.ConfigValue) error { + var ok bool + c.Password, ok = pass.(string) + if !ok { + return fmt.Errorf("td.connect.pass requires string got %T", pass) + } + return nil +} + +func (c *config) setGroupID(groupID tmq.ConfigValue) error { + var ok bool + c.GroupID, ok = groupID.(string) + if !ok { + return fmt.Errorf("group.id requires string got %T", groupID) + } + return nil +} + +func (c *config) setClientID(clientID tmq.ConfigValue) error { + var ok bool + c.ClientID, ok = clientID.(string) + if !ok { + return fmt.Errorf("client.id requires string got %T", clientID) + } + return nil +} + +func (c *config) setAutoOffsetReset(offsetReset tmq.ConfigValue) error { + var ok bool + c.OffsetRest, ok = offsetReset.(string) + if !ok { + return fmt.Errorf("auto.offset.reset requires string got %T", offsetReset) + } + return nil +} + +func (c *config) setMessageTimeout(timeout tmq.ConfigValue) error { + var ok bool + c.MessageTimeout, ok = timeout.(time.Duration) + if !ok { + return fmt.Errorf("ws.message.timeout requires time.Duration got %T", timeout) + } + if c.MessageTimeout < time.Second { + return errors.New("ws.message.timeout cannot be less than 1 second") + } + return nil +} + +func (c *config) setWriteWait(writeWait tmq.ConfigValue) error { + var ok bool + c.WriteWait, ok = writeWait.(time.Duration) + if !ok { + return fmt.Errorf("ws.message.writeWait requires time.Duration got %T", writeWait) + } + if c.WriteWait < time.Second { + return errors.New("ws.message.writeWait cannot be less than 1 second") + } + if c.WriteWait < 0 { + return errors.New("ws.message.writeWait cannot be less than 0") + } + return nil +} + +func (c *config) setAutoCommit(enable tmq.ConfigValue) error { + var ok bool + c.AutoCommit, ok = enable.(string) + if !ok { + return fmt.Errorf("enable.auto.commit requires string got %T", enable) + } + return nil +} + +func (c *config) setAutoCommitIntervalMS(autoCommitIntervalMS tmq.ConfigValue) error { + var ok bool + c.AutoCommitIntervalMS, ok = autoCommitIntervalMS.(string) + if !ok { + return fmt.Errorf("auto.commit.interval.ms requires string got %T", autoCommitIntervalMS) + } + return nil +} + +func (c *config) setSnapshotEnable(enableSnapshot tmq.ConfigValue) error { + var ok bool + c.SnapshotEnable, ok = enableSnapshot.(string) + if !ok { + return fmt.Errorf("experimental.snapshot.enable requires string got %T", enableSnapshot) + } + return nil +} + +func (c *config) setWithTableName(withTableName tmq.ConfigValue) error { + var ok bool + c.SnapshotEnable, ok = withTableName.(string) + if !ok { + return fmt.Errorf("msg.with.table.name requires string got %T", withTableName) + } + return nil +} diff --git a/ws/tmq/consumer.go b/ws/tmq/consumer.go new file mode 100644 index 0000000..24a5c23 --- /dev/null +++ b/ws/tmq/consumer.go @@ -0,0 +1,648 @@ +package tmq + +import ( + "container/list" + "context" + "encoding/binary" + "errors" + "fmt" + "sync" + "sync/atomic" + "time" + "unsafe" + + "github.com/gorilla/websocket" + jsoniter "github.com/json-iterator/go" + "github.com/taosdata/driver-go/v3/common" + "github.com/taosdata/driver-go/v3/common/parser" + "github.com/taosdata/driver-go/v3/common/tmq" + taosErrors "github.com/taosdata/driver-go/v3/errors" + "github.com/taosdata/driver-go/v3/ws/client" +) + +type Consumer struct { + client *client.Client + requestID uint64 + err error + latestMessageID uint64 + listLock sync.RWMutex + sendChanList *list.List + messageTimeout time.Duration + url string + user string + password string + groupID string + clientID string + offsetRest string + autoCommit string + autoCommitIntervalMS string + snapshotEnable string + withTableName string + closeOnce sync.Once + closeChan chan struct{} +} + +type IndexedChan struct { + index uint64 + channel chan []byte +} + +type WSError struct { + err error +} + +func (e *WSError) Error() string { + return fmt.Sprintf("websocket close with error %s", e.err) +} + +// NewConsumer create a tmq consumer +func NewConsumer(conf *tmq.ConfigMap) (*Consumer, error) { + confCopy := conf.Clone() + config, err := configMapToConfig(&confCopy) + if err != nil { + return nil, err + } + ws, _, err := common.DefaultDialer.Dial(config.Url, nil) + if err != nil { + return nil, err + } + wsClient := client.NewClient(ws, config.ChanLength) + tmq := &Consumer{ + client: wsClient, + requestID: 0, + sendChanList: list.New(), + messageTimeout: config.MessageTimeout, + url: config.Url, + user: config.User, + password: config.Password, + groupID: config.GroupID, + clientID: config.ClientID, + offsetRest: config.OffsetRest, + autoCommit: config.AutoCommit, + autoCommitIntervalMS: config.AutoCommitIntervalMS, + snapshotEnable: config.SnapshotEnable, + withTableName: config.WithTableName, + closeChan: make(chan struct{}), + } + if config.WriteWait > 0 { + wsClient.WriteWait = config.WriteWait + } + wsClient.BinaryMessageHandler = tmq.handleBinaryMessage + wsClient.TextMessageHandler = tmq.handleTextMessage + wsClient.ErrorHandler = tmq.handleError + go wsClient.WritePump() + go wsClient.ReadPump() + return tmq, nil +} + +func configMapToConfig(m *tmq.ConfigMap) (*config, error) { + url, err := m.Get("ws.url", "") + if err != nil { + return nil, err + } + if url == "" { + return nil, errors.New("ws.url required") + } + chanLen, err := m.Get("ws.message.channelLen", uint(0)) + if err != nil { + return nil, err + } + messageTimeout, err := m.Get("ws.message.timeout", common.DefaultMessageTimeout) + if err != nil { + return nil, err + } + writeWait, err := m.Get("ws.message.writeWait", common.DefaultWriteWait) + if err != nil { + return nil, err + } + user, err := m.Get("td.connect.user", "") + if err != nil { + return nil, err + } + pass, err := m.Get("td.connect.pass", "") + if err != nil { + return nil, err + } + groupID, err := m.Get("group.id", "") + if err != nil { + return nil, err + } + clientID, err := m.Get("client.id", "") + if err != nil { + return nil, err + } + offsetReset, err := m.Get("auto.offset.reset", "") + if err != nil { + return nil, err + } + enableAutoCommit, err := m.Get("enable.auto.commit", "") + if err != nil { + return nil, err + } + //auto.commit.interval.ms + autoCommitIntervalMS, err := m.Get("auto.commit.interval.ms", "") + if err != nil { + return nil, err + } + enableSnapshot, err := m.Get("experimental.snapshot.enable", "") + if err != nil { + return nil, err + } + withTableName, err := m.Get("msg.with.table.name", "") + if err != nil { + return nil, err + } + config := newConfig(url.(string), chanLen.(uint)) + err = config.setMessageTimeout(messageTimeout.(time.Duration)) + if err != nil { + return nil, err + } + err = config.setWriteWait(writeWait.(time.Duration)) + if err != nil { + return nil, err + } + err = config.setConnectUser(user) + if err != nil { + return nil, err + } + err = config.setConnectPass(pass) + if err != nil { + return nil, err + } + err = config.setGroupID(groupID) + if err != nil { + return nil, err + } + err = config.setClientID(clientID) + if err != nil { + return nil, err + } + err = config.setAutoOffsetReset(offsetReset) + if err != nil { + return nil, err + } + err = config.setAutoCommit(enableAutoCommit) + if err != nil { + return nil, err + } + err = config.setAutoCommitIntervalMS(autoCommitIntervalMS) + if err != nil { + return nil, err + } + err = config.setSnapshotEnable(enableSnapshot) + if err != nil { + return nil, err + } + err = config.setWithTableName(withTableName) + if err != nil { + return nil, err + } + return config, nil +} + +func (c *Consumer) handleTextMessage(message []byte) { + iter := client.JsonI.BorrowIterator(message) + var reqID uint64 + iter.ReadObjectCB(func(iter *jsoniter.Iterator, s string) bool { + switch s { + case "req_id": + reqID = iter.ReadUint64() + return false + default: + iter.Skip() + } + return iter.Error == nil + }) + client.JsonI.ReturnIterator(iter) + c.listLock.Lock() + element := c.findOutChanByID(reqID) + if element != nil { + element.Value.(*IndexedChan).channel <- message + c.sendChanList.Remove(element) + } + c.listLock.Unlock() +} + +func (c *Consumer) handleBinaryMessage(message []byte) { + reqID := binary.LittleEndian.Uint64(message[8:16]) + c.listLock.Lock() + element := c.findOutChanByID(reqID) + if element != nil { + element.Value.(*IndexedChan).channel <- message + c.sendChanList.Remove(element) + } + c.listLock.Unlock() +} + +func (c *Consumer) handleError(err error) { + c.err = &WSError{err: err} + c.Close() +} + +func (c *Consumer) generateReqID() uint64 { + return atomic.AddUint64(&c.requestID, 1) +} + +// Close consumer. This function can be called multiple times +func (c *Consumer) Close() error { + c.closeOnce.Do(func() { + close(c.closeChan) + c.client.Close() + }) + return nil +} + +func (c *Consumer) addMessageOutChan(outChan *IndexedChan) *list.Element { + c.listLock.Lock() + element := c.sendChanList.PushBack(outChan) + c.listLock.Unlock() + return element +} + +func (c *Consumer) findOutChanByID(index uint64) *list.Element { + root := c.sendChanList.Front() + if root == nil { + return nil + } + rootIndex := root.Value.(*IndexedChan).index + if rootIndex == index { + return root + } + item := root.Next() + for { + if item == nil || item == root { + return nil + } + if item.Value.(*IndexedChan).index == index { + return item + } + item = item.Next() + } +} + +const ( + TMQSubscribe = "subscribe" + TMQPoll = "poll" + TMQFetch = "fetch" + TMQFetchBlock = "fetch_block" + TMQFetchJsonMeta = "fetch_json_meta" + TMQCommit = "commit" + TMQUnsubscribe = "unsubscribe" +) + +var ClosedErr = errors.New("connection closed") + +func (c *Consumer) sendText(reqID uint64, envelope *client.Envelope) ([]byte, error) { + if !c.client.IsRunning() { + c.client.PutEnvelope(envelope) + return nil, ClosedErr + } + channel := &IndexedChan{ + index: reqID, + channel: make(chan []byte, 1), + } + element := c.addMessageOutChan(channel) + envelope.Type = websocket.TextMessage + c.client.Send(envelope) + ctx, cancel := context.WithTimeout(context.Background(), c.messageTimeout) + defer cancel() + select { + case <-c.closeChan: + return nil, ClosedErr + case resp := <-channel.channel: + return resp, nil + case <-ctx.Done(): + c.listLock.Lock() + c.sendChanList.Remove(element) + c.listLock.Unlock() + return nil, fmt.Errorf("message timeout :%s", envelope.Msg.String()) + } +} + +type RebalanceCb func(*Consumer, tmq.Event) error + +func (c *Consumer) Subscribe(topic string, rebalanceCb RebalanceCb) error { + return c.SubscribeTopics([]string{topic}, rebalanceCb) +} + +func (c *Consumer) SubscribeTopics(topics []string, rebalanceCb RebalanceCb) error { + if c.err != nil { + return c.err + } + reqID := c.generateReqID() + req := &SubscribeReq{ + ReqID: reqID, + User: c.user, + Password: c.password, + GroupID: c.groupID, + ClientID: c.clientID, + OffsetRest: c.offsetRest, + Topics: topics, + AutoCommit: c.autoCommit, + AutoCommitIntervalMS: c.autoCommitIntervalMS, + SnapshotEnable: c.snapshotEnable, + WithTableName: c.withTableName, + } + args, err := client.JsonI.Marshal(req) + if err != nil { + return err + } + action := &client.WSAction{ + Action: TMQSubscribe, + Args: args, + } + envelope := c.client.GetEnvelope() + err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) + if err != nil { + c.client.PutEnvelope(envelope) + return err + } + respBytes, err := c.sendText(reqID, envelope) + if err != nil { + return err + } + var resp SubscribeResp + err = client.JsonI.Unmarshal(respBytes, &resp) + if err != nil { + return err + } + if resp.Code != 0 { + return taosErrors.NewError(resp.Code, resp.Message) + } + return nil +} + +// Poll messages +func (c *Consumer) Poll(timeoutMs int) tmq.Event { + if c.err != nil { + panic(c.err) + } + reqID := c.generateReqID() + req := &PollReq{ + ReqID: reqID, + BlockingTime: int64(timeoutMs), + } + args, err := client.JsonI.Marshal(req) + if err != nil { + return tmq.NewTMQErrorWithErr(err) + } + action := &client.WSAction{ + Action: TMQPoll, + Args: args, + } + envelope := c.client.GetEnvelope() + err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) + if err != nil { + c.client.PutEnvelope(envelope) + return tmq.NewTMQErrorWithErr(err) + } + respBytes, err := c.sendText(reqID, envelope) + if err != nil { + return tmq.NewTMQErrorWithErr(err) + } + var resp PollResp + err = client.JsonI.Unmarshal(respBytes, &resp) + if err != nil { + return tmq.NewTMQErrorWithErr(err) + } + if resp.Code != 0 { + panic(taosErrors.NewError(resp.Code, resp.Message)) + } + c.latestMessageID = resp.MessageID + if resp.HaveMessage { + switch resp.MessageType { + case common.TMQ_RES_DATA: + result := &tmq.DataMessage{} + result.SetDbName(resp.Database) + result.SetTopic(resp.Topic) + data, err := c.fetch(resp.MessageID) + if err != nil { + return tmq.NewTMQErrorWithErr(err) + } + result.SetData(data) + return result + case common.TMQ_RES_TABLE_META: + result := &tmq.MetaMessage{} + result.SetDbName(resp.Database) + result.SetTopic(resp.Topic) + meta, err := c.fetchJsonMeta(resp.MessageID) + if err != nil { + return tmq.NewTMQErrorWithErr(err) + } + result.SetMeta(meta) + return result + case common.TMQ_RES_METADATA: + result := &tmq.MetaDataMessage{} + result.SetDbName(resp.Database) + result.SetTopic(resp.Topic) + meta, err := c.fetchJsonMeta(resp.MessageID) + if err != nil { + return tmq.NewTMQErrorWithErr(err) + } + data, err := c.fetch(resp.MessageID) + if err != nil { + return tmq.NewTMQErrorWithErr(err) + } + result.SetMetaData(&tmq.MetaData{ + Meta: meta, + Data: data, + }) + return result + default: + return tmq.NewTMQErrorWithErr(err) + } + } else { + return nil + } +} + +func (c *Consumer) fetchJsonMeta(messageID uint64) (*tmq.Meta, error) { + reqID := c.generateReqID() + req := &FetchJsonMetaReq{ + ReqID: reqID, + MessageID: messageID, + } + args, err := client.JsonI.Marshal(req) + if err != nil { + return nil, err + } + action := &client.WSAction{ + Action: TMQFetchJsonMeta, + Args: args, + } + envelope := c.client.GetEnvelope() + err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) + if err != nil { + c.client.PutEnvelope(envelope) + return nil, err + } + respBytes, err := c.sendText(reqID, envelope) + if err != nil { + return nil, err + } + var resp FetchJsonMetaResp + err = client.JsonI.Unmarshal(respBytes, &resp) + if err != nil { + return nil, err + } + if resp.Code != 0 { + return nil, taosErrors.NewError(resp.Code, resp.Message) + } + var meta tmq.Meta + err = client.JsonI.Unmarshal(resp.Data, &meta) + if err != nil { + return nil, err + } + return &meta, nil +} + +func (c *Consumer) fetch(messageID uint64) ([]*tmq.Data, error) { + var tmqData []*tmq.Data + for { + reqID := c.generateReqID() + req := &FetchReq{ + ReqID: reqID, + MessageID: messageID, + } + args, err := client.JsonI.Marshal(req) + if err != nil { + return nil, err + } + action := &client.WSAction{ + Action: TMQFetch, + Args: args, + } + envelope := c.client.GetEnvelope() + err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) + if err != nil { + c.client.PutEnvelope(envelope) + return nil, err + } + respBytes, err := c.sendText(reqID, envelope) + if err != nil { + return nil, err + } + var resp FetchResp + err = client.JsonI.Unmarshal(respBytes, &resp) + if err != nil { + return nil, err + } + if resp.Code != 0 { + return nil, taosErrors.NewError(resp.Code, resp.Message) + } + if resp.Completed { + break + } + // fetch block + { + req := &FetchBlockReq{ + ReqID: reqID, + MessageID: messageID, + } + args, err := client.JsonI.Marshal(req) + if err != nil { + return nil, err + } + action := &client.WSAction{ + Action: TMQFetchBlock, + Args: args, + } + envelope := c.client.GetEnvelope() + err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) + if err != nil { + c.client.PutEnvelope(envelope) + return nil, err + } + respBytes, err := c.sendText(reqID, envelope) + if err != nil { + return nil, err + } + block := respBytes[24:] + data := parser.ReadBlock(unsafe.Pointer(*(*uintptr)(unsafe.Pointer(&block))), resp.Rows, resp.FieldsTypes, resp.Precision) + tmqData = append(tmqData, &tmq.Data{ + TableName: resp.TableName, + Data: data, + }) + } + } + return tmqData, nil +} + +func (c *Consumer) Commit() ([]tmq.TopicPartition, error) { + return c.doCommit(c.latestMessageID) +} + +func (c *Consumer) doCommit(messageID uint64) ([]tmq.TopicPartition, error) { + if c.err != nil { + return nil, c.err + } + reqID := c.generateReqID() + req := &CommitReq{ + ReqID: reqID, + MessageID: messageID, + } + args, err := client.JsonI.Marshal(req) + if err != nil { + return nil, err + } + action := &client.WSAction{ + Action: TMQCommit, + Args: args, + } + envelope := c.client.GetEnvelope() + err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) + if err != nil { + c.client.PutEnvelope(envelope) + return nil, err + } + respBytes, err := c.sendText(reqID, envelope) + if err != nil { + return nil, err + } + var resp CommitResp + err = client.JsonI.Unmarshal(respBytes, &resp) + if err != nil { + return nil, err + } + if resp.Code != 0 { + return nil, taosErrors.NewError(resp.Code, resp.Message) + } + return nil, nil +} + +func (c *Consumer) Unsubscribe() error { + if c.err != nil { + return c.err + } + reqID := c.generateReqID() + req := &UnsubscribeReq{ + ReqID: reqID, + } + args, err := client.JsonI.Marshal(req) + if err != nil { + return err + } + action := &client.WSAction{ + Action: TMQUnsubscribe, + Args: args, + } + envelope := c.client.GetEnvelope() + err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) + if err != nil { + c.client.PutEnvelope(envelope) + return err + } + respBytes, err := c.sendText(reqID, envelope) + if err != nil { + return err + } + var resp CommitResp + err = client.JsonI.Unmarshal(respBytes, &resp) + if err != nil { + return err + } + if resp.Code != 0 { + return taosErrors.NewError(resp.Code, resp.Message) + } + return nil +} diff --git a/ws/tmq/consumer_test.go b/ws/tmq/consumer_test.go new file mode 100644 index 0000000..3b948b5 --- /dev/null +++ b/ws/tmq/consumer_test.go @@ -0,0 +1,284 @@ +package tmq + +import ( + "fmt" + "io/ioutil" + "net/http" + "strings" + "testing" + "time" + + jsoniter "github.com/json-iterator/go" + "github.com/stretchr/testify/assert" + "github.com/taosdata/driver-go/v3/common" + "github.com/taosdata/driver-go/v3/common/tmq" + taosErrors "github.com/taosdata/driver-go/v3/errors" + "github.com/taosdata/driver-go/v3/ws/client" +) + +func prepareEnv() error { + var err error + steps := []string{ + "drop topic if exists test_ws_tmq_topic", + "drop database if exists test_ws_tmq", + "create database test_ws_tmq WAL_RETENTION_PERIOD 86400", + "create topic test_ws_tmq_topic with meta as database test_ws_tmq", + } + for _, step := range steps { + err = doRequest(step) + if err != nil { + return err + } + } + return nil +} + +func cleanEnv() error { + var err error + time.Sleep(2 * time.Second) + steps := []string{ + "drop topic if exists test_ws_tmq_topic", + "drop database if exists test_ws_tmq", + } + for _, step := range steps { + err = doRequest(step) + if err != nil { + return err + } + } + return nil +} + +func doRequest(payload string) error { + body := strings.NewReader(payload) + req, _ := http.NewRequest(http.MethodPost, "http://127.0.0.1:6041/rest/sql", body) + req.Header.Set("Authorization", "Taosd /KfeAzX/f9na8qdtNZmtONryp201ma04bEl8LcvLUd7a8qdtNZmtONryp201ma04") + resp, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("http code: %d", resp.StatusCode) + } + data, err := ioutil.ReadAll(resp.Body) + if err != nil { + return err + } + iter := client.JsonI.BorrowIterator(data) + code := int32(0) + desc := "" + iter.ReadObjectCB(func(iter *jsoniter.Iterator, s string) bool { + switch s { + case "code": + code = iter.ReadInt32() + case "desc": + desc = iter.ReadString() + default: + iter.Skip() + } + return iter.Error == nil + }) + client.JsonI.ReturnIterator(iter) + if code != 0 { + return taosErrors.NewError(int(code), desc) + } + return nil +} + +func TestConsumer(t *testing.T) { + err := prepareEnv() + if err != nil { + t.Error(err) + return + } + defer cleanEnv() + now := time.Now() + go func() { + err = doRequest("create table test_ws_tmq.t_all(ts timestamp," + + "c1 bool," + + "c2 tinyint," + + "c3 smallint," + + "c4 int," + + "c5 bigint," + + "c6 tinyint unsigned," + + "c7 smallint unsigned," + + "c8 int unsigned," + + "c9 bigint unsigned," + + "c10 float," + + "c11 double," + + "c12 binary(20)," + + "c13 nchar(20)" + + ")") + if err != nil { + t.Error(err) + return + } + err = doRequest(fmt.Sprintf("insert into test_ws_tmq.t_all values('%s',true,2,3,4,5,6,7,8,9,10.123,11.123,'binary','nchar')", now.Format(time.RFC3339Nano))) + if err != nil { + t.Error(err) + return + } + }() + consumer, err := NewConsumer(&tmq.ConfigMap{ + "ws.url": "ws://127.0.0.1:6041/rest/tmq", + "ws.message.channelLen": uint(0), + "ws.message.timeout": common.DefaultMessageTimeout, + "ws.message.writeWait": common.DefaultWriteWait, + "td.connect.user": "root", + "td.connect.pass": "taosdata", + "group.id": "test", + "client.id": "test_consumer", + "auto.offset.reset": "earliest", + "enable.auto.commit": "true", + "auto.commit.interval.ms": "5000", + "experimental.snapshot.enable": "true", + "msg.with.table.name": "true", + }) + if err != nil { + t.Error(err) + return + } + defer consumer.Close() + topic := []string{"test_ws_tmq_topic"} + err = consumer.SubscribeTopics(topic, nil) + if err != nil { + t.Error(err) + return + } + gotMeta := false + gotData := false + for i := 0; i < 5; i++ { + if gotData && gotMeta { + return + } + ev := consumer.Poll(0) + if ev != nil { + switch e := ev.(type) { + case *tmq.DataMessage: + gotData = true + data := e.Value().([]*tmq.Data) + assert.Equal(t, "test_ws_tmq", e.DBName()) + assert.Equal(t, 1, len(data)) + assert.Equal(t, "t_all", data[0].TableName) + assert.Equal(t, 1, len(data[0].Data)) + assert.Equal(t, now.Unix(), data[0].Data[0][0].(time.Time).Unix()) + var v = data[0].Data[0] + assert.Equal(t, true, v[1].(bool)) + assert.Equal(t, int8(2), v[2].(int8)) + assert.Equal(t, int16(3), v[3].(int16)) + assert.Equal(t, int32(4), v[4].(int32)) + assert.Equal(t, int64(5), v[5].(int64)) + assert.Equal(t, uint8(6), v[6].(uint8)) + assert.Equal(t, uint16(7), v[7].(uint16)) + assert.Equal(t, uint32(8), v[8].(uint32)) + assert.Equal(t, uint64(9), v[9].(uint64)) + assert.Equal(t, float32(10.123), v[10].(float32)) + assert.Equal(t, float64(11.123), v[11].(float64)) + assert.Equal(t, "binary", v[12].(string)) + assert.Equal(t, "nchar", v[13].(string)) + case *tmq.MetaMessage: + gotMeta = true + meta := e.Value().(*tmq.Meta) + assert.Equal(t, "test_ws_tmq", e.DBName()) + assert.Equal(t, "create", meta.Type) + assert.Equal(t, "t_all", meta.TableName) + assert.Equal(t, "normal", meta.TableType) + assert.Equal(t, []*tmq.Column{ + { + Name: "ts", + Type: 9, + Length: 0, + }, + { + Name: "c1", + Type: 1, + Length: 0, + }, + { + Name: "c2", + Type: 2, + Length: 0, + }, + { + Name: "c3", + Type: 3, + Length: 0, + }, + { + Name: "c4", + Type: 4, + Length: 0, + }, + { + Name: "c5", + Type: 5, + Length: 0, + }, + { + Name: "c6", + Type: 11, + Length: 0, + }, + { + Name: "c7", + Type: 12, + Length: 0, + }, + { + Name: "c8", + Type: 13, + Length: 0, + }, + { + Name: "c9", + Type: 14, + Length: 0, + }, + { + Name: "c10", + Type: 6, + Length: 0, + }, + { + Name: "c11", + Type: 7, + Length: 0, + }, + { + Name: "c12", + Type: 8, + Length: 20, + }, + { + Name: "c13", + Type: 10, + Length: 20, + }}, meta.Columns) + case tmq.Error: + t.Error(e) + return + default: + t.Error("unexpected", e) + return + } + _, err = consumer.Commit() + } + + if err != nil { + t.Error(err) + return + } + } + if !gotMeta { + t.Error("no meta got") + } + if !gotData { + t.Error("no data got") + } + err = consumer.Unsubscribe() + if err != nil { + t.Error(err) + return + } +} diff --git a/ws/tmq/proto.go b/ws/tmq/proto.go new file mode 100644 index 0000000..a5376eb --- /dev/null +++ b/ws/tmq/proto.go @@ -0,0 +1,113 @@ +package tmq + +import "encoding/json" + +type SubscribeReq struct { + ReqID uint64 `json:"req_id"` + User string `json:"user"` + Password string `json:"password"` + DB string `json:"db"` + GroupID string `json:"group_id"` + ClientID string `json:"client_id"` + OffsetRest string `json:"offset_rest"` + Topics []string `json:"topics"` + AutoCommit string `json:"auto_commit"` + AutoCommitIntervalMS string `json:"auto_commit_interval_ms"` + SnapshotEnable string `json:"snapshot_enable"` + WithTableName string `json:"with_table_name"` +} + +type SubscribeResp struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` +} + +type PollReq struct { + ReqID uint64 `json:"req_id"` + BlockingTime int64 `json:"blocking_time"` +} + +type PollResp struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + HaveMessage bool `json:"have_message"` + Topic string `json:"topic"` + Database string `json:"database"` + VgroupID int32 `json:"vgroup_id"` + MessageType int32 `json:"message_type"` + MessageID uint64 `json:"message_id"` +} + +type FetchJsonMetaReq struct { + ReqID uint64 `json:"req_id"` + MessageID uint64 `json:"message_id"` +} + +type FetchJsonMetaResp struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + MessageID uint64 `json:"message_id"` + Data json.RawMessage `json:"data"` +} + +type FetchReq struct { + ReqID uint64 `json:"req_id"` + MessageID uint64 `json:"message_id"` +} + +type FetchResp struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + MessageID uint64 `json:"message_id"` + Completed bool `json:"completed"` + TableName string `json:"table_name"` + Rows int `json:"rows"` + FieldsCount int `json:"fields_count"` + FieldsNames []string `json:"fields_names"` + FieldsTypes []uint8 `json:"fields_types"` + FieldsLengths []int64 `json:"fields_lengths"` + Precision int `json:"precision"` +} + +type FetchBlockReq struct { + ReqID uint64 `json:"req_id"` + MessageID uint64 `json:"message_id"` +} + +type CommitReq struct { + ReqID uint64 `json:"req_id"` + MessageID uint64 `json:"message_id"` +} + +type CommitResp struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + MessageID uint64 `json:"message_id"` +} + +type UnsubscribeReq struct { + ReqID uint64 `json:"req_id"` +} + +type UnsubscribeResp struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` +}