diff --git a/.github/workflows/compatibility.yml b/.github/workflows/compatibility.yml new file mode 100644 index 0000000..956433b --- /dev/null +++ b/.github/workflows/compatibility.yml @@ -0,0 +1,140 @@ +name: compatibility + +on: + pull_request: + branches: + - '3.1' + +jobs: + build: + runs-on: ubuntu-22.04 + strategy: + matrix: + td_version: [ 'main', '3.0' ] + name: Build ${{ matrix.td_version }} + outputs: + commit_id: ${{ steps.get_commit_id.outputs.commit_id }} + steps: + - name: checkout TDengine by pr + if: github.event_name == 'pull_request' + uses: actions/checkout@v3 + with: + repository: 'taosdata/TDengine' + path: 'TDengine' + ref: ${{ matrix.td_version }} + + - name: get_commit_id + id: get_commit_id + run: | + cd TDengine + echo "commit_id=$(git rev-parse HEAD)" >> $GITHUB_OUTPUT + + + - name: Cache server by pr + if: github.event_name == 'pull_request' + id: cache-server-pr + uses: actions/cache@v3 + with: + path: server.tar.gz + key: ${{ runner.os }}-build-${{ matrix.td_version }}-${{ steps.get_commit_id.outputs.commit_id }} + + - name: prepare install + if: > + (github.event_name == 'pull_request' && steps.cache-server-pr.outputs.cache-hit != 'true') + run: sudo apt install -y libgeos-dev + + - name: install TDengine + if: > + (github.event_name == 'pull_request' && steps.cache-server-pr.outputs.cache-hit != 'true') + run: | + cd TDengine + mkdir debug + cd debug + cmake .. -DBUILD_TEST=off -DBUILD_HTTP=false -DVERNUMBER=3.9.9.9 + make -j 4 + + - name: package + if: > + (github.event_name == 'pull_request' && steps.cache-server-pr.outputs.cache-hit != 'true') + run: | + mkdir -p ./release + cp ./TDengine/debug/build/bin/taos ./release/ + cp ./TDengine/debug/build/bin/taosd ./release/ + cp ./TDengine/tools/taosadapter/taosadapter ./release/ + cp ./TDengine/debug/build/lib/libtaos.so.3.9.9.9 ./release/ + cp ./TDengine/debug/build/lib/librocksdb.so.8.1.1 ./release/ ||: + cp ./TDengine/include/client/taos.h ./release/ + cat >./release/install.sh<start.sh<> $GITHUB_OUTPUT - - name: Run sccache-cache - uses: mozilla-actions/sccache-action@v0.0.3 - name: Cache server by pr if: github.event_name == 'pull_request' @@ -77,7 +74,7 @@ jobs: cd TDengine mkdir debug cd debug - cmake .. -DBUILD_TEST=off -DBUILD_HTTP=false -DVERNUMBER=3.9.9.9 -DCMAKE_C_COMPILER_LAUNCHER=sccache -DCMAKE_CXX_COMPILER_LAUNCHER=sccache + cmake .. -DBUILD_TEST=off -DBUILD_HTTP=false -DVERNUMBER=3.9.9.9 make -j 4 - name: package @@ -114,7 +111,7 @@ jobs: needs: build strategy: matrix: - go: [ '1.14', '1.19' ] + go: [ '1.14', 'stable' ] name: Go ${{ matrix.go }} steps: - name: get cache server by pr @@ -137,11 +134,6 @@ jobs: restore-keys: | ${{ runner.os }}-build-${{ inputs.tbBranch }}- - - name: checkout - uses: actions/checkout@v3 - with: - path: 'driver-go' - - name: prepare install run: sudo apt install -y libgeos-dev @@ -150,6 +142,15 @@ jobs: tar -zxvf server.tar.gz cd release && sudo sh install.sh + - name: checkout + uses: actions/checkout@v3 + + - name: copy taos cfg + run: | + sudo mkdir -p /etc/taos + sudo cp ./.github/workflows/taos.cfg /etc/taos/taos.cfg + sudo cp ./.github/workflows/taosadapter.toml /etc/taos/taosadapter.toml + - name: shell run: | cat >start.sh<> $GITHUB_OUTPUT - - name: Run sccache-cache - uses: mozilla-actions/sccache-action@v0.0.3 - name: Cache server id: cache-server @@ -43,8 +40,6 @@ jobs: if: steps.cache-server.outputs.cache-hit != 'true' run: sudo apt install -y libgeos-dev - - name: Run sccache-cache - uses: mozilla-actions/sccache-action@v0.0.3 - name: install TDengine if: steps.cache-server.outputs.cache-hit != 'true' @@ -52,7 +47,7 @@ jobs: cd TDengine mkdir debug cd debug - cmake .. -DBUILD_JDBC=false -DBUILD_TEST=off -DBUILD_HTTP=false -DVERNUMBER=3.9.9.9 -DCMAKE_C_COMPILER_LAUNCHER=sccache -DCMAKE_CXX_COMPILER_LAUNCHER=sccache + cmake .. -DBUILD_JDBC=false -DBUILD_TEST=off -DBUILD_HTTP=false -DVERNUMBER=3.9.9.9 make -j 4 - name: package @@ -87,7 +82,7 @@ jobs: needs: build strategy: matrix: - go: [ '1.14', '1.19' ] + go: [ '1.14', 'stable' ] name: Go ${{ matrix.go }} steps: - name: get cache server @@ -99,11 +94,6 @@ jobs: restore-keys: | ${{ runner.os }}-build-${{ github.ref_name }}- - - name: checkout - uses: actions/checkout@v3 - with: - path: 'driver-go' - - name: prepare install run: sudo apt install -y libgeos-dev @@ -112,6 +102,15 @@ jobs: tar -zxvf server.tar.gz cd release && sudo sh install.sh + - name: checkout + uses: actions/checkout@v3 + + - name: copy taos cfg + run: | + sudo mkdir -p /etc/taos + sudo cp ./.github/workflows/taos.cfg /etc/taos/taos.cfg + sudo cp ./.github/workflows/taosadapter.toml /etc/taos/taosadapter.toml + - name: shell run: | cat >start.sh<= rs.blockSize { if err := rs.taosFetchBlock(); err != nil { return err } @@ -67,16 +67,6 @@ func (rs *rows) Next(dest []driver.Value) error { return io.EOF } - if rs.blockOffset >= rs.blockSize { - if err := rs.taosFetchBlock(); err != nil { - return err - } - } - if rs.blockSize == 0 { - rs.block = nil - rs.freeResult() - return io.EOF - } parser.ReadRow(dest, rs.block, rs.blockSize, rs.blockOffset, rs.rowsHeader.ColTypes, rs.precision) rs.blockOffset++ return nil @@ -111,9 +101,11 @@ func (rs *rows) asyncFetchRows() *handler.AsyncResult { func (rs *rows) freeResult() { if rs.result != nil { - locker.Lock() - wrapper.TaosFreeResult(rs.result) - locker.Unlock() + if !rs.isStmt { + locker.Lock() + wrapper.TaosFreeResult(rs.result) + locker.Unlock() + } rs.result = nil } diff --git a/af/stmt.go b/af/stmt.go index 18d52c2..dbe849b 100644 --- a/af/stmt.go +++ b/af/stmt.go @@ -2,9 +2,11 @@ package af import "C" import ( + "database/sql/driver" "fmt" "unsafe" + "github.com/taosdata/driver-go/v3/af/async" "github.com/taosdata/driver-go/v3/af/locker" "github.com/taosdata/driver-go/v3/common/param" taosError "github.com/taosdata/driver-go/v3/errors" @@ -36,22 +38,22 @@ func (s *Stmt) Prepare(sql string) error { code := wrapper.TaosStmtPrepare(s.stmt, sql) locker.Unlock() if code != 0 { - errStr := wrapper.TaosStmtErrStr(s.stmt) - return taosError.NewError(code, errStr) + return s.stmtErr(code) } isInsert, code := wrapper.TaosStmtIsInsert(s.stmt) if code != 0 { - errStr := wrapper.TaosStmtErrStr(s.stmt) - return taosError.NewError(code, errStr) + return s.stmtErr(code) } s.isInsert = isInsert + return nil +} + +func (s *Stmt) NumParams() (int, error) { numParams, code := wrapper.TaosStmtNumParams(s.stmt) if code != 0 { - errStr := wrapper.TaosStmtErrStr(s.stmt) - return taosError.NewError(code, errStr) + return 0, s.stmtErr(code) } - s.paramCount = numParams - return nil + return numParams, nil } func (s *Stmt) SetTableNameWithTags(tableName string, tags *param.Param) error { @@ -59,8 +61,7 @@ func (s *Stmt) SetTableNameWithTags(tableName string, tags *param.Param) error { code := wrapper.TaosStmtSetTBNameTags(s.stmt, tableName, tags.GetValues()) locker.Unlock() if code != 0 { - errStr := wrapper.TaosStmtErrStr(s.stmt) - return taosError.NewError(code, errStr) + return s.stmtErr(code) } return nil } @@ -70,36 +71,33 @@ func (s *Stmt) SetTableName(tableName string) error { code := wrapper.TaosStmtSetTBName(s.stmt, tableName) locker.Unlock() if code != 0 { - errStr := wrapper.TaosStmtErrStr(s.stmt) - return taosError.NewError(code, errStr) + return s.stmtErr(code) } return nil } func (s *Stmt) BindRow(row *param.Param) error { - if s.paramCount == 0 { - locker.Lock() - code := wrapper.TaosStmtBindParam(s.stmt, nil) - locker.Unlock() - if code != 0 { - errStr := wrapper.TaosStmtErrStr(s.stmt) - return taosError.NewError(code, errStr) + if s.isInsert { + if s.paramCount == 0 { + paramCount, err := s.NumParams() + if err != nil { + return err + } + s.paramCount = paramCount } - return nil } if row == nil { return fmt.Errorf("row param got nil") } value := row.GetValues() - if len(value) != s.paramCount { + if s.isInsert && len(value) != s.paramCount { return fmt.Errorf("row param count error : expect %d got %d", s.paramCount, len(value)) } locker.Lock() code := wrapper.TaosStmtBindParam(s.stmt, value) locker.Unlock() if code != 0 { - errStr := wrapper.TaosStmtErrStr(s.stmt) - return taosError.NewError(code, errStr) + return s.stmtErr(code) } return nil } @@ -116,8 +114,7 @@ func (s *Stmt) AddBatch() error { code := wrapper.TaosStmtAddBatch(s.stmt) locker.Unlock() if code != 0 { - errStr := wrapper.TaosStmtErrStr(s.stmt) - return taosError.NewError(code, errStr) + return s.stmtErr(code) } return nil } @@ -127,20 +124,45 @@ func (s *Stmt) Execute() error { code := wrapper.TaosStmtExecute(s.stmt) locker.Unlock() if code != 0 { - errStr := wrapper.TaosStmtErrStr(s.stmt) - return taosError.NewError(code, errStr) + return s.stmtErr(code) } return nil } +func (s *Stmt) UseResult() (driver.Rows, error) { + locker.Lock() + res := wrapper.TaosStmtUseResult(s.stmt) + locker.Unlock() + numFields := wrapper.TaosNumFields(res) + rowsHeader, err := wrapper.ReadColumn(res, numFields) + h := async.GetHandler() + if err != nil { + async.PutHandler(h) + return nil, err + } + precision := wrapper.TaosResultPrecision(res) + rs := &rows{ + handler: h, + rowsHeader: rowsHeader, + result: res, + precision: precision, + isStmt: true, + } + return rs, nil +} + func (s *Stmt) Close() error { locker.Lock() code := wrapper.TaosStmtClose(s.stmt) locker.Unlock() s.stmt = nil if code != 0 { - errStr := wrapper.TaosStmtErrStr(s.stmt) - return taosError.NewError(code, errStr) + return s.stmtErr(code) } return nil } + +func (s *Stmt) stmtErr(code int) error { + errStr := wrapper.TaosStmtErrStr(s.stmt) + return taosError.NewError(code, errStr) +} diff --git a/af/stmt_test.go b/af/stmt_test.go new file mode 100644 index 0000000..6a23e9f --- /dev/null +++ b/af/stmt_test.go @@ -0,0 +1,55 @@ +package af + +import ( + "database/sql/driver" + "io" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/taosdata/driver-go/v3/common" + "github.com/taosdata/driver-go/v3/common/param" +) + +func TestNewStmt(t *testing.T) { + db := testDatabase(t) + _, err := db.Exec("create table test_stmt (ts timestamp,v int)") + assert.NoError(t, err) + stmt := db.Stmt() + err = stmt.Prepare("insert into ? values(?,?)") + assert.NoError(t, err) + err = stmt.SetTableName("test_stmt") + assert.NoError(t, err) + ts := time.Now().UnixNano() / 1e3 + err = stmt.BindRow(param.NewParam(2).AddTimestamp(time.Unix(0, ts*1e3), common.PrecisionMicroSecond).AddInt(1)) + assert.NoError(t, err) + err = stmt.AddBatch() + assert.NoError(t, err) + err = stmt.Execute() + assert.NoError(t, err) + affected := stmt.GetAffectedRows() + assert.Equal(t, int(1), affected) + err = stmt.Prepare("select * from test_stmt where v = ?") + assert.NoError(t, err) + err = stmt.BindRow(param.NewParam(1).AddInt(1)) + assert.NoError(t, err) + err = stmt.AddBatch() + assert.NoError(t, err) + err = stmt.Execute() + assert.NoError(t, err) + rows, err := stmt.UseResult() + assert.NoError(t, err) + dest := make([]driver.Value, 2) + err = rows.Next(dest) + assert.NoError(t, err) + assert.Equal(t, ts, dest[0].(time.Time).UnixNano()/1e3) + assert.Equal(t, int32(1), dest[1].(int32)) + err = rows.Next(dest) + assert.ErrorIs(t, err, io.EOF) + err = rows.Close() + assert.NoError(t, err) + err = stmt.Close() + assert.NoError(t, err) + err = db.Close() + assert.NoError(t, err) +} diff --git a/af/tmq/consumer.go b/af/tmq/consumer.go index ef2338e..aab447c 100644 --- a/af/tmq/consumer.go +++ b/af/tmq/consumer.go @@ -13,7 +13,8 @@ import ( ) type Consumer struct { - cConsumer unsafe.Pointer + cConsumer unsafe.Pointer + dataParser *parser.TMQRawDataParser } // NewConsumer Create new TMQ consumer with TMQ config @@ -28,7 +29,8 @@ func NewConsumer(conf *tmq.ConfigMap) (*Consumer, error) { return nil, err } consumer := &Consumer{ - cConsumer: cConsumer, + cConsumer: cConsumer, + dataParser: parser.NewTMQRawDataParser(), } return consumer, nil } @@ -63,14 +65,12 @@ func (c *Consumer) SubscribeTopics(topics []string, rebalanceCb RebalanceCb) err for _, topic := range topics { errCode := wrapper.TMQListAppend(topicList, topic) if errCode != 0 { - errStr := wrapper.TMQErr2Str(errCode) - return taosError.NewError(int(errCode), errStr) + return c.tmqError(errCode) } } errCode := wrapper.TMQSubscribe(c.cConsumer, topicList) if errCode != 0 { - errStr := wrapper.TMQErr2Str(errCode) - return taosError.NewError(int(errCode), errStr) + return c.tmqError(errCode) } return nil } @@ -79,8 +79,7 @@ func (c *Consumer) SubscribeTopics(topics []string, rebalanceCb RebalanceCb) err func (c *Consumer) Unsubscribe() error { errCode := wrapper.TMQUnsubscribe(c.cConsumer) if errCode != taosError.SUCCESS { - errStr := wrapper.TMQErr2Str(errCode) - return taosError.NewError(int(errCode), errStr) + return c.tmqError(errCode) } return nil } @@ -176,27 +175,22 @@ func (c *Consumer) getMeta(message unsafe.Pointer) (*tmq.Meta, error) { } func (c *Consumer) getData(message unsafe.Pointer) ([]*tmq.Data, error) { + errCode, raw := wrapper.TMQGetRaw(message) + if errCode != taosError.SUCCESS { + errStr := wrapper.TaosErrorStr(message) + err := taosError.NewError(int(errCode), errStr) + return nil, err + } + _, _, rawPtr := wrapper.ParseRawMeta(raw) + blockInfos, err := c.dataParser.Parse(rawPtr) + if err != nil { + return nil, err + } var tmqData []*tmq.Data - for { - blockSize, errCode, block := wrapper.TaosFetchRawBlock(message) - if errCode != int(taosError.SUCCESS) { - errStr := wrapper.TaosErrorStr(message) - err := taosError.NewError(errCode, errStr) - return nil, err - } - if blockSize == 0 { - break - } - tableName := wrapper.TMQGetTableName(message) - fileCount := wrapper.TaosNumFields(message) - rh, err := wrapper.ReadColumn(message, fileCount) - if err != nil { - return nil, err - } - precision := wrapper.TaosResultPrecision(message) + for i := 0; i < len(blockInfos); i++ { tmqData = append(tmqData, &tmq.Data{ - TableName: tableName, - Data: parser.ReadBlock(block, blockSize, rh.ColTypes, precision), + TableName: blockInfos[i].TableName, + Data: parser.ReadBlockSimple(blockInfos[i].RawBlock, blockInfos[i].Precision), }) } return tmqData, nil @@ -205,8 +199,7 @@ func (c *Consumer) getData(message unsafe.Pointer) ([]*tmq.Data, error) { func (c *Consumer) Commit() ([]tmq.TopicPartition, error) { errCode := wrapper.TMQCommitSync(c.cConsumer, nil) if errCode != taosError.SUCCESS { - errStr := wrapper.TMQErr2Str(errCode) - return nil, taosError.NewError(int(errCode), errStr) + return nil, c.tmqError(errCode) } partitions, err := c.Assignment() if err != nil { @@ -218,8 +211,7 @@ func (c *Consumer) Commit() ([]tmq.TopicPartition, error) { func (c *Consumer) doCommit(message unsafe.Pointer) ([]tmq.TopicPartition, error) { errCode := wrapper.TMQCommitSync(c.cConsumer, message) if errCode != taosError.SUCCESS { - errStr := wrapper.TMQErr2Str(errCode) - return nil, taosError.NewError(int(errCode), errStr) + return nil, c.tmqError(errCode) } return nil, nil } @@ -227,8 +219,7 @@ func (c *Consumer) doCommit(message unsafe.Pointer) ([]tmq.TopicPartition, error func (c *Consumer) Assignment() (partitions []tmq.TopicPartition, err error) { errCode, list := wrapper.TMQSubscription(c.cConsumer) if errCode != taosError.SUCCESS { - errStr := wrapper.TMQErr2Str(errCode) - return nil, taosError.NewError(int(errCode), errStr) + return nil, c.tmqError(errCode) } defer wrapper.TMQListDestroy(list) size := wrapper.TMQListGetSize(list) @@ -236,8 +227,7 @@ func (c *Consumer) Assignment() (partitions []tmq.TopicPartition, err error) { for _, topic := range topics { errCode, assignment := wrapper.TMQGetTopicAssignment(c.cConsumer, topic) if errCode != taosError.SUCCESS { - errStr := wrapper.TMQErr2Str(errCode) - return nil, taosError.NewError(int(errCode), errStr) + return nil, c.tmqError(errCode) } for i := 0; i < len(assignment); i++ { topicName := topic @@ -254,8 +244,7 @@ func (c *Consumer) Assignment() (partitions []tmq.TopicPartition, err error) { func (c *Consumer) Seek(partition tmq.TopicPartition, ignoredTimeoutMs int) error { errCode := wrapper.TMQOffsetSeek(c.cConsumer, *partition.Topic, partition.Partition, int64(partition.Offset)) if errCode != taosError.SUCCESS { - errStr := wrapper.TMQErr2Str(errCode) - return taosError.NewError(int(errCode), errStr) + return c.tmqError(errCode) } return nil } @@ -266,7 +255,7 @@ func (c *Consumer) Committed(partitions []tmq.TopicPartition, timeoutMs int) (of cOffset := wrapper.TMQCommitted(c.cConsumer, *partitions[i].Topic, partitions[i].Partition) offset := tmq.Offset(cOffset) if !offset.Valid() { - return nil, taosError.NewError(int(offset), wrapper.TMQErr2Str(int32(offset))) + return nil, c.tmqError(int32(offset)) } offsets[i] = tmq.TopicPartition{ Topic: partitions[i].Topic, @@ -281,8 +270,7 @@ func (c *Consumer) CommitOffsets(offsets []tmq.TopicPartition) ([]tmq.TopicParti for i := 0; i < len(offsets); i++ { errCode := wrapper.TMQCommitOffsetSync(c.cConsumer, *offsets[i].Topic, offsets[i].Partition, int64(offsets[i].Offset)) if errCode != taosError.SUCCESS { - errStr := wrapper.TMQErr2Str(errCode) - return nil, taosError.NewError(int(errCode), errStr) + return nil, c.tmqError(errCode) } } return c.Committed(offsets, 0) @@ -293,7 +281,7 @@ func (c *Consumer) Position(partitions []tmq.TopicPartition) (offsets []tmq.Topi for i := 0; i < len(partitions); i++ { position := wrapper.TMQPosition(c.cConsumer, *partitions[i].Topic, partitions[i].Partition) if position < 0 { - return nil, taosError.NewError(int(position), wrapper.TMQErr2Str(int32(position))) + return nil, c.tmqError(int32(position)) } offsets[i] = tmq.TopicPartition{ Topic: partitions[i].Topic, @@ -308,8 +296,12 @@ func (c *Consumer) Position(partitions []tmq.TopicPartition) (offsets []tmq.Topi func (c *Consumer) Close() error { errCode := wrapper.TMQConsumerClose(c.cConsumer) if errCode != 0 { - errStr := wrapper.TMQErr2Str(errCode) - return taosError.NewError(int(errCode), errStr) + return c.tmqError(errCode) } return nil } + +func (c *Consumer) tmqError(errCode int32) error { + errStr := wrapper.TMQErr2Str(errCode) + return taosError.NewError(int(errCode), errStr) +} diff --git a/af/tmq/consumer_test.go b/af/tmq/consumer_test.go index 4dd024c..47252b7 100644 --- a/af/tmq/consumer_test.go +++ b/af/tmq/consumer_test.go @@ -68,15 +68,14 @@ func TestTmq(t *testing.T) { assert.NoError(t, err) consumer, err := NewConsumer(&tmq.ConfigMap{ - "group.id": "test", - "auto.offset.reset": "earliest", - "td.connect.ip": "127.0.0.1", - "td.connect.user": "root", - "td.connect.pass": "taosdata", - "td.connect.port": "6030", - "client.id": "test_tmq_c", - "enable.auto.commit": "false", - //"experimental.snapshot.enable": "true", + "group.id": "test", + "auto.offset.reset": "earliest", + "td.connect.ip": "127.0.0.1", + "td.connect.user": "root", + "td.connect.pass": "taosdata", + "td.connect.port": "6030", + "client.id": "test_tmq_c", + "enable.auto.commit": "false", "msg.with.table.name": "true", }) if err != nil { @@ -181,7 +180,7 @@ func TestSeek(t *testing.T) { } defer func() { - //execWithoutResult(conn, "drop database if exists "+db) + execWithoutResult(conn, "drop database if exists "+db) }() for _, sql := range sqls { err = execWithoutResult(conn, sql) @@ -309,3 +308,188 @@ func execWithoutResult(conn unsafe.Pointer, sql string) error { } return nil } + +func prepareMultiBlockEnv(conn unsafe.Pointer) error { + var err error + steps := []string{ + "drop topic if exists test_tmq_multi_block_topic", + "drop database if exists test_tmq_multi_block", + "create database test_tmq_multi_block vgroups 1 WAL_RETENTION_PERIOD 86400", + "create topic test_tmq_multi_block_topic as database test_tmq_multi_block", + "create table test_tmq_multi_block.t1(ts timestamp,v int)", + "create table test_tmq_multi_block.t2(ts timestamp,v int)", + "create table test_tmq_multi_block.t3(ts timestamp,v int)", + "create table test_tmq_multi_block.t4(ts timestamp,v int)", + "create table test_tmq_multi_block.t5(ts timestamp,v int)", + "create table test_tmq_multi_block.t6(ts timestamp,v int)", + "create table test_tmq_multi_block.t7(ts timestamp,v int)", + "create table test_tmq_multi_block.t8(ts timestamp,v int)", + "create table test_tmq_multi_block.t9(ts timestamp,v int)", + "create table test_tmq_multi_block.t10(ts timestamp,v int)", + "insert into test_tmq_multi_block.t1 values (now,1) test_tmq_multi_block.t2 values (now,2) " + + "test_tmq_multi_block.t3 values (now,3) test_tmq_multi_block.t4 values (now,4)" + + "test_tmq_multi_block.t5 values (now,5) test_tmq_multi_block.t6 values (now,6)" + + "test_tmq_multi_block.t7 values (now,7) test_tmq_multi_block.t8 values (now,8)" + + "test_tmq_multi_block.t9 values (now,9) test_tmq_multi_block.t10 values (now,10)", + } + for _, step := range steps { + err = execWithoutResult(conn, step) + if err != nil { + return err + } + } + return nil +} + +func cleanMultiBlockEnv(conn unsafe.Pointer) error { + var err error + time.Sleep(2 * time.Second) + steps := []string{ + "drop topic if exists test_tmq_multi_block_topic", + "drop database if exists test_tmq_multi_block", + } + for _, step := range steps { + err = execWithoutResult(conn, step) + if err != nil { + return err + } + } + return nil +} + +func TestMultiBlock(t *testing.T) { + conn, err := wrapper.TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + defer wrapper.TaosClose(conn) + err = prepareMultiBlockEnv(conn) + assert.NoError(t, err) + defer cleanMultiBlockEnv(conn) + consumer, err := NewConsumer(&tmq.ConfigMap{ + "group.id": "test", + "td.connect.ip": "127.0.0.1", + "td.connect.user": "root", + "td.connect.pass": "taosdata", + "td.connect.port": "6030", + "auto.offset.reset": "earliest", + "client.id": "test_tmq_multi_block_topic", + "enable.auto.commit": "false", + "msg.with.table.name": "true", + }) + assert.NoError(t, err) + if err != nil { + t.Error(err) + return + } + defer func() { + consumer.Unsubscribe() + consumer.Close() + }() + topic := []string{"test_tmq_multi_block_topic"} + err = consumer.SubscribeTopics(topic, nil) + if err != nil { + t.Error(err) + return + } + for i := 0; i < 10; i++ { + event := consumer.Poll(500) + if event == nil { + continue + } + switch e := event.(type) { + case *tmq.DataMessage: + data := e.Value().([]*tmq.Data) + assert.Equal(t, "test_tmq_multi_block", e.DBName()) + assert.Equal(t, 10, len(data)) + return + } + } +} + +func prepareMetaEnv(conn unsafe.Pointer) error { + var err error + steps := []string{ + "drop topic if exists test_tmq_meta_topic", + "drop database if exists test_tmq_meta", + "create database test_tmq_meta vgroups 1 WAL_RETENTION_PERIOD 86400", + "create topic test_tmq_meta_topic with meta as database test_tmq_meta", + } + for _, step := range steps { + err = execWithoutResult(conn, step) + if err != nil { + return err + } + } + return nil +} + +func cleanMetaEnv(conn unsafe.Pointer) error { + var err error + time.Sleep(2 * time.Second) + steps := []string{ + "drop topic if exists test_tmq_meta_topic", + "drop database if exists test_tmq_meta", + } + for _, step := range steps { + err = execWithoutResult(conn, step) + if err != nil { + return err + } + } + return nil +} + +func TestMeta(t *testing.T) { + conn, err := wrapper.TaosConnect("", "root", "taosdata", "", 0) + assert.NoError(t, err) + defer wrapper.TaosClose(conn) + err = prepareMetaEnv(conn) + assert.NoError(t, err) + defer cleanMetaEnv(conn) + consumer, err := NewConsumer(&tmq.ConfigMap{ + "group.id": "test", + "td.connect.ip": "127.0.0.1", + "td.connect.user": "root", + "td.connect.pass": "taosdata", + "td.connect.port": "6030", + "auto.offset.reset": "earliest", + "client.id": "test_tmq_multi_block_topic", + "enable.auto.commit": "false", + "msg.with.table.name": "true", + }) + err = consumer.Subscribe("test_tmq_meta_topic", nil) + assert.NoError(t, err) + defer func() { + consumer.Unsubscribe() + consumer.Close() + }() + go func() { + execWithoutResult(conn, "create table test_tmq_meta.st(ts timestamp,v int) tags (cn binary(20))") + execWithoutResult(conn, "create table test_tmq_meta.t1 using test_tmq_meta.st tags ('t1')") + execWithoutResult(conn, "insert into test_tmq_meta.t1 values (now,1)") + execWithoutResult(conn, "insert into test_tmq_meta.t2 using test_tmq_meta.st tags ('t1') values (now,2)") + time.Sleep(time.Second) + execWithoutResult(conn, "insert into test_tmq_meta.t1 values (now,1)") + execWithoutResult(conn, "insert into test_tmq_meta.t1 values (now,1)") + }() + for i := 0; i < 10; i++ { + event := consumer.Poll(500) + if event == nil { + continue + } + switch e := event.(type) { + case *tmq.DataMessage: + t.Log(e) + assert.Equal(t, "test_tmq_meta", e.DBName()) + case *tmq.MetaDataMessage: + assert.Equal(t, "test_tmq_meta", e.DBName()) + assert.Equal(t, "test_tmq_meta_topic", e.Topic()) + t.Log(e) + case *tmq.MetaMessage: + assert.Equal(t, "test_tmq_meta", e.DBName()) + t.Log(e) + } + } +} diff --git a/common/datatype.go b/common/datatype.go index ea85688..b5406ed 100644 --- a/common/datatype.go +++ b/common/datatype.go @@ -1,7 +1,6 @@ package common import ( - "errors" "reflect" ) @@ -288,12 +287,3 @@ var NameTypeMap = map[string]int{ TSDB_DATA_TYPE_VARBINARY_Str: TSDB_DATA_TYPE_VARBINARY, TSDB_DATA_TYPE_GEOMETRY_Str: TSDB_DATA_TYPE_GEOMETRY, } - -var NotSupportType = errors.New("not support type") - -func GetColType(colType int) (*DBType, error) { - if colType > len(allType) || colType < 0 { - return nil, NotSupportType - } - return allType[colType], nil -} diff --git a/common/param/column.go b/common/param/column.go index de5dedc..1542f70 100644 --- a/common/param/column.go +++ b/common/param/column.go @@ -16,6 +16,10 @@ func NewColumnType(size int) *ColumnType { return &ColumnType{size: size, value: make([]*types.ColumnType, size)} } +func NewColumnTypeWithValue(value []*types.ColumnType) *ColumnType { + return &ColumnType{size: len(value), value: value, column: len(value)} +} + func (c *ColumnType) AddBool() *ColumnType { if c.column >= c.size { return c diff --git a/common/param/column_test.go b/common/param/column_test.go new file mode 100644 index 0000000..a684a8c --- /dev/null +++ b/common/param/column_test.go @@ -0,0 +1,435 @@ +package param + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/taosdata/driver-go/v3/types" +) + +func TestColumnType_AddBool(t *testing.T) { + colType := NewColumnType(1) + colType.AddBool() + + expected := []*types.ColumnType{ + { + Type: types.TaosBoolType, + }, + } + + values, err := colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) + + colType.AddBool() + + values, err = colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) +} + +func TestColumnType_AddTinyint(t *testing.T) { + colType := NewColumnType(1) + + colType.AddTinyint() + + expected := []*types.ColumnType{ + { + Type: types.TaosTinyintType, + }, + } + + values, err := colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) + + colType.AddTinyint() + + values, err = colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) +} + +func TestColumnType_AddSmallint(t *testing.T) { + colType := NewColumnType(1) + + colType.AddSmallint() + + expected := []*types.ColumnType{ + { + Type: types.TaosSmallintType, + }, + } + + values, err := colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) + + colType.AddSmallint() + + values, err = colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) +} + +func TestColumnType_AddInt(t *testing.T) { + colType := NewColumnType(1) + + colType.AddInt() + + expected := []*types.ColumnType{ + { + Type: types.TaosIntType, + }, + } + + values, err := colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) + + colType.AddInt() + + values, err = colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) +} + +func TestColumnType_AddBigint(t *testing.T) { + colType := NewColumnType(1) + + colType.AddBigint() + + expected := []*types.ColumnType{ + { + Type: types.TaosBigintType, + }, + } + + values, err := colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) + + colType.AddBigint() + + values, err = colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) +} + +func TestColumnType_AddUTinyint(t *testing.T) { + colType := NewColumnType(1) + + colType.AddUTinyint() + + expected := []*types.ColumnType{ + { + Type: types.TaosUTinyintType, + }, + } + + values, err := colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) + + colType.AddUTinyint() + + values, err = colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) +} + +func TestColumnType_AddUSmallint(t *testing.T) { + colType := NewColumnType(1) + + colType.AddUSmallint() + + expected := []*types.ColumnType{ + { + Type: types.TaosUSmallintType, + }, + } + + values, err := colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) + + colType.AddUSmallint() + + values, err = colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) +} + +func TestColumnType_AddUInt(t *testing.T) { + colType := NewColumnType(1) + + colType.AddUInt() + + expected := []*types.ColumnType{ + { + Type: types.TaosUIntType, + }, + } + + values, err := colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) + + colType.AddUInt() + + values, err = colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) +} + +func TestColumnType_AddUBigint(t *testing.T) { + colType := NewColumnType(1) + + colType.AddUBigint() + + expected := []*types.ColumnType{ + { + Type: types.TaosUBigintType, + }, + } + + values, err := colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) + + colType.AddUBigint() + + values, err = colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) +} + +func TestColumnType_AddFloat(t *testing.T) { + colType := NewColumnType(1) + + colType.AddFloat() + + expected := []*types.ColumnType{ + { + Type: types.TaosFloatType, + }, + } + + values, err := colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) + + colType.AddFloat() + + values, err = colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) +} + +func TestColumnType_AddDouble(t *testing.T) { + colType := NewColumnType(1) + + colType.AddDouble() + + expected := []*types.ColumnType{ + { + Type: types.TaosDoubleType, + }, + } + + values, err := colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) + + colType.AddDouble() + + values, err = colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) +} + +func TestColumnType_AddBinary(t *testing.T) { + colType := NewColumnType(1) + + colType.AddBinary(100) + + expected := []*types.ColumnType{ + { + Type: types.TaosBinaryType, + MaxLen: 100, + }, + } + + values, err := colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) + + colType.AddBinary(50) + + values, err = colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) +} + +func TestColumnType_AddVarBinary(t *testing.T) { + colType := NewColumnType(1) + + colType.AddVarBinary(100) + + expected := []*types.ColumnType{ + { + Type: types.TaosVarBinaryType, + MaxLen: 100, + }, + } + + values, err := colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) + + colType.AddVarBinary(50) + values, err = colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) +} + +func TestColumnType_AddNchar(t *testing.T) { + colType := NewColumnType(1) + + colType.AddNchar(100) + + expected := []*types.ColumnType{ + { + Type: types.TaosNcharType, + MaxLen: 100, + }, + } + + values, err := colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) + + colType.AddNchar(50) + + values, err = colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) +} + +func TestColumnType_AddTimestamp(t *testing.T) { + colType := NewColumnType(1) + + colType.AddTimestamp() + + expected := []*types.ColumnType{ + { + Type: types.TaosTimestampType, + }, + } + + values, err := colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) + + colType.AddTimestamp() + + values, err = colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) +} + +func TestColumnType_AddJson(t *testing.T) { + colType := NewColumnType(1) + + colType.AddJson(100) + + expected := []*types.ColumnType{ + { + Type: types.TaosJsonType, + MaxLen: 100, + }, + } + + values, err := colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) + + colType.AddJson(50) + + values, err = colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) +} + +func TestColumnType_AddGeometry(t *testing.T) { + colType := NewColumnType(1) + + colType.AddGeometry(100) + + expected := []*types.ColumnType{ + { + Type: types.TaosGeometryType, + MaxLen: 100, + }, + } + + values, err := colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) + + colType.AddGeometry(50) + + values, err = colType.GetValue() + assert.NoError(t, err) + assert.Equal(t, expected, values) +} + +func TestColumnType_GetValue(t *testing.T) { + // Initialize ColumnType with size 3 + colType := NewColumnType(3) + + // Add column types + colType.AddBool() + colType.AddTinyint() + colType.AddFloat() + + // Try to get values + values, err := colType.GetValue() + assert.NoError(t, err) + + // Check if the length of values matches the expected size + expectedSize := 3 + assert.Equal(t, expectedSize, len(values)) + + // Initialize ColumnType with size 3 + colType = NewColumnType(3) + + // Add only 2 column types + colType.AddBool() + colType.AddTinyint() + + // Try to get values + _, err = colType.GetValue() + + // Check if an error is returned due to incomplete column + assert.Error(t, err) + assert.Equal(t, "incomplete column expect 3 columns set 2 columns", err.Error()) +} + +func TestNewColumnTypeWithValue(t *testing.T) { + value := []*types.ColumnType{ + {Type: types.TaosBoolType}, + {Type: types.TaosTinyintType}, + } + + colType := NewColumnTypeWithValue(value) + + expectedSize := len(value) + assert.Equal(t, expectedSize, colType.size) + + expectedValue := value + assert.Equal(t, expectedValue, colType.value) + + expectedColumn := len(value) + assert.Equal(t, expectedColumn, colType.column) +} diff --git a/common/param/param.go b/common/param/param.go index a14854b..cce09d0 100644 --- a/common/param/param.go +++ b/common/param/param.go @@ -20,6 +20,15 @@ func NewParam(size int) *Param { } } +func NewParamsWithRowValue(value []driver.Value) []*Param { + params := make([]*Param, len(value)) + for i, d := range value { + params[i] = NewParam(1) + params[i].AddValue(d) + } + return params +} + func (p *Param) SetBool(offset int, value bool) { if offset >= p.size { return diff --git a/common/param/param_test.go b/common/param/param_test.go new file mode 100644 index 0000000..8bf3a18 --- /dev/null +++ b/common/param/param_test.go @@ -0,0 +1,654 @@ +package param + +import ( + "database/sql/driver" + "testing" + "time" + + "github.com/stretchr/testify/assert" + taosTypes "github.com/taosdata/driver-go/v3/types" +) + +func TestParam_SetBool(t *testing.T) { + param := NewParam(1) + param.SetBool(0, true) + + expected := []driver.Value{taosTypes.TaosBool(true)} + assert.Equal(t, expected, param.GetValues()) + + param = NewParam(0) + param.SetBool(0, true) + assert.Equal(t, 0, len(param.GetValues())) +} + +func TestParam_SetNull(t *testing.T) { + param := NewParam(1) + param.SetNull(0) + + if param.GetValues()[0] != nil { + t.Errorf("SetNull failed, expected nil, got %v", param.GetValues()[0]) + } + param = NewParam(0) + param.SetNull(0) + assert.Equal(t, 0, len(param.GetValues())) +} + +func TestParam_SetTinyint(t *testing.T) { + param := NewParam(1) + param.SetTinyint(0, 42) + + expected := []driver.Value{taosTypes.TaosTinyint(42)} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.SetTinyint(1, 42) // Attempt to set at index 1 with size 1 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_SetSmallint(t *testing.T) { + param := NewParam(1) + param.SetSmallint(0, 42) + + expected := []driver.Value{taosTypes.TaosSmallint(42)} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.SetSmallint(1, 42) // Attempt to set at index 1 with size 1 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_SetInt(t *testing.T) { + param := NewParam(1) + param.SetInt(0, 42) + + expected := []driver.Value{taosTypes.TaosInt(42)} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.SetInt(1, 42) // Attempt to set at index 1 with size 1 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_SetBigint(t *testing.T) { + param := NewParam(1) + param.SetBigint(0, 42) + + expected := []driver.Value{taosTypes.TaosBigint(42)} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.SetBigint(1, 42) // Attempt to set at index 1 with size 1 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_SetUTinyint(t *testing.T) { + param := NewParam(1) + param.SetUTinyint(0, 42) + + expected := []driver.Value{taosTypes.TaosUTinyint(42)} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.SetUTinyint(1, 42) // Attempt to set at index 1 with size 1 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_SetUSmallint(t *testing.T) { + param := NewParam(1) + param.SetUSmallint(0, 42) + + expected := []driver.Value{taosTypes.TaosUSmallint(42)} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.SetUSmallint(1, 42) // Attempt to set at index 1 with size 1 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_SetUInt(t *testing.T) { + param := NewParam(1) + param.SetUInt(0, 42) + + expected := []driver.Value{taosTypes.TaosUInt(42)} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.SetUInt(1, 42) // Attempt to set at index 1 with size 1 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_SetUBigint(t *testing.T) { + param := NewParam(1) + param.SetUBigint(0, 42) + + expected := []driver.Value{taosTypes.TaosUBigint(42)} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.SetUBigint(1, 42) // Attempt to set at index 1 with size 1 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_SetFloat(t *testing.T) { + param := NewParam(1) + param.SetFloat(0, 3.14) + + expected := []driver.Value{taosTypes.TaosFloat(3.14)} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.SetFloat(1, 3.14) // Attempt to set at index 1 with size 1 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_SetDouble(t *testing.T) { + param := NewParam(1) + param.SetDouble(0, 3.14) + + expected := []driver.Value{taosTypes.TaosDouble(3.14)} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.SetDouble(1, 3.14) // Attempt to set at index 1 with size 1 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_SetBinary(t *testing.T) { + param := NewParam(1) + param.SetBinary(0, []byte{0x01, 0x02}) + + expected := []driver.Value{taosTypes.TaosBinary([]byte{0x01, 0x02})} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.SetBinary(1, []byte{0x01, 0x02}) // Attempt to set at index 1 with size 1 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_SetVarBinary(t *testing.T) { + param := NewParam(1) + param.SetVarBinary(0, []byte{0x01, 0x02}) + + expected := []driver.Value{taosTypes.TaosVarBinary([]byte{0x01, 0x02})} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.SetVarBinary(1, []byte{0x01, 0x02}) // Attempt to set at index 1 with size 1 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_SetNchar(t *testing.T) { + param := NewParam(1) + param.SetNchar(0, "hello") + + expected := []driver.Value{taosTypes.TaosNchar("hello")} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.SetNchar(1, "hello") // Attempt to set at index 1 with size 1 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_SetTimestamp(t *testing.T) { + timestamp := time.Date(2022, time.January, 1, 12, 0, 0, 0, time.UTC) + param := NewParam(1) + param.SetTimestamp(0, timestamp, 6) + + expected := []driver.Value{taosTypes.TaosTimestamp{T: timestamp, Precision: 6}} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.SetTimestamp(1, timestamp, 6) // Attempt to set at index 1 with size 1 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_SetJson(t *testing.T) { + jsonData := []byte(`{"key": "value"}`) + param := NewParam(1) + param.SetJson(0, jsonData) + + expected := []driver.Value{taosTypes.TaosJson(jsonData)} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.SetJson(1, jsonData) // Attempt to set at index 1 with size 1 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_SetGeometry(t *testing.T) { + geometryData := []byte{0x01, 0x02, 0x03, 0x04} + param := NewParam(1) + param.SetGeometry(0, geometryData) + + expected := []driver.Value{taosTypes.TaosGeometry(geometryData)} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.SetGeometry(1, geometryData) // Attempt to set at index 1 with size 1 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_AddBool(t *testing.T) { + param := NewParam(2) // Initialize with size 2 + + // Add a bool value + param.AddBool(true) + + expected := []driver.Value{taosTypes.TaosBool(true), nil} + assert.Equal(t, expected, param.GetValues()) + + // Add another bool value + param.AddBool(false) + + expected = []driver.Value{taosTypes.TaosBool(true), taosTypes.TaosBool(false)} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.AddBool(true) // Attempt to add at index 2 with size 2 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_AddNull(t *testing.T) { + param := NewParam(2) // Initialize with size 2 + + // Add a null value + param.AddNull() + + expected := []driver.Value{nil, nil} + assert.Equal(t, expected, param.GetValues()) + + // Add another null value + param.AddNull() + + expected = []driver.Value{nil, nil} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.AddNull() // Attempt to add at index 2 with size 2 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_AddTinyint(t *testing.T) { + param := NewParam(2) // Initialize with size 2 + + // Add a tinyint value + param.AddTinyint(42) + + expected := []driver.Value{taosTypes.TaosTinyint(42), nil} + assert.Equal(t, expected, param.GetValues()) + + // Add another tinyint value + param.AddTinyint(84) + + expected = []driver.Value{taosTypes.TaosTinyint(42), taosTypes.TaosTinyint(84)} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.AddTinyint(126) // Attempt to add at index 2 with size 2 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_AddSmallint(t *testing.T) { + param := NewParam(2) // Initialize with size 2 + + // Add a smallint value + param.AddSmallint(42) + + expected := []driver.Value{taosTypes.TaosSmallint(42), nil} + assert.Equal(t, expected, param.GetValues()) + + // Add another smallint value + param.AddSmallint(84) + + expected = []driver.Value{taosTypes.TaosSmallint(42), taosTypes.TaosSmallint(84)} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.AddSmallint(126) // Attempt to add at index 2 with size 2 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_AddInt(t *testing.T) { + param := NewParam(2) // Initialize with size 2 + + // Add an int value + param.AddInt(42) + + expected := []driver.Value{taosTypes.TaosInt(42), nil} + assert.Equal(t, expected, param.GetValues()) + + // Add another int value + param.AddInt(84) + + expected = []driver.Value{taosTypes.TaosInt(42), taosTypes.TaosInt(84)} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.AddInt(126) // Attempt to add at index 2 with size 2 + assert.Equal(t, expected, param.GetValues()) // Should not mod +} + +func TestParam_AddBigint(t *testing.T) { + param := NewParam(2) // Initialize with size 2 + + // Add a bigint value + param.AddBigint(42) + + expected := []driver.Value{taosTypes.TaosBigint(42), nil} + assert.Equal(t, expected, param.GetValues()) + + // Add another bigint value + param.AddBigint(84) + + expected = []driver.Value{taosTypes.TaosBigint(42), taosTypes.TaosBigint(84)} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.AddBigint(126) // Attempt to add at index 2 with size 2 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_AddUTinyint(t *testing.T) { + param := NewParam(2) // Initialize with size 2 + + // Add a utinyint value + param.AddUTinyint(42) + + expected := []driver.Value{taosTypes.TaosUTinyint(42), nil} + assert.Equal(t, expected, param.GetValues()) + + // Add another utinyint value + param.AddUTinyint(84) + + expected = []driver.Value{taosTypes.TaosUTinyint(42), taosTypes.TaosUTinyint(84)} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.AddUTinyint(126) // Attempt to add at index 2 with size 2 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_AddUSmallint(t *testing.T) { + param := NewParam(2) // Initialize with size 2 + + // Add a usmallint value + param.AddUSmallint(42) + + expected := []driver.Value{taosTypes.TaosUSmallint(42), nil} + assert.Equal(t, expected, param.GetValues()) + + // Add another usmallint value + param.AddUSmallint(84) + + expected = []driver.Value{taosTypes.TaosUSmallint(42), taosTypes.TaosUSmallint(84)} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.AddUSmallint(126) // Attempt to add at index 2 with size 2 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_AddUInt(t *testing.T) { + param := NewParam(2) // Initialize with size 2 + + // Add a uint value + param.AddUInt(42) + + expected := []driver.Value{taosTypes.TaosUInt(42), nil} + assert.Equal(t, expected, param.GetValues()) + + // Add another uint value + param.AddUInt(84) + + expected = []driver.Value{taosTypes.TaosUInt(42), taosTypes.TaosUInt(84)} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.AddUInt(126) // Attempt to add at index 2 with size 2 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_AddUBigint(t *testing.T) { + param := NewParam(2) // Initialize with size 2 + + // Add a ubigint value + param.AddUBigint(42) + + expected := []driver.Value{taosTypes.TaosUBigint(42), nil} + assert.Equal(t, expected, param.GetValues()) + + // Add another ubigint value + param.AddUBigint(84) + + expected = []driver.Value{taosTypes.TaosUBigint(42), taosTypes.TaosUBigint(84)} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.AddUBigint(126) // Attempt to add at index 2 with size 2 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_AddFloat(t *testing.T) { + param := NewParam(2) // Initialize with size 2 + + // Add a float value + param.AddFloat(3.14) + + expected := []driver.Value{taosTypes.TaosFloat(3.14), nil} + assert.Equal(t, expected, param.GetValues()) + + // Add another float value + param.AddFloat(6.28) + + expected = []driver.Value{taosTypes.TaosFloat(3.14), taosTypes.TaosFloat(6.28)} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.AddFloat(9.42) // Attempt to add at index 2 with size 2 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_AddDouble(t *testing.T) { + param := NewParam(2) // Initialize with size 2 + + // Add a double value + param.AddDouble(3.14) + + expected := []driver.Value{taosTypes.TaosDouble(3.14), nil} + assert.Equal(t, expected, param.GetValues()) + + // Add another double value + param.AddDouble(6.28) + + expected = []driver.Value{taosTypes.TaosDouble(3.14), taosTypes.TaosDouble(6.28)} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.AddDouble(9.42) // Attempt to add at index 2 with size 2 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_AddBinary(t *testing.T) { + param := NewParam(2) // Initialize with size 2 + binaryData := []byte{0x01, 0x02, 0x03} + + // Add a binary value + param.AddBinary(binaryData) + + expected := []driver.Value{taosTypes.TaosBinary(binaryData), nil} + assert.Equal(t, expected, param.GetValues()) + + // Add another binary value + param.AddBinary([]byte{0x04, 0x05, 0x06}) + + expected = []driver.Value{taosTypes.TaosBinary(binaryData), taosTypes.TaosBinary([]byte{0x04, 0x05, 0x06})} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.AddBinary([]byte{0x07, 0x08, 0x09}) // Attempt to add at index 2 with size 2 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_AddVarBinary(t *testing.T) { + param := NewParam(2) // Initialize with size 2 + binaryData := []byte{0x01, 0x02, 0x03} + + // Add a varbinary value + param.AddVarBinary(binaryData) + + expected := []driver.Value{taosTypes.TaosVarBinary(binaryData), nil} + assert.Equal(t, expected, param.GetValues()) + + // Add another varbinary value + param.AddVarBinary([]byte{0x04, 0x05, 0x06}) + + expected = []driver.Value{taosTypes.TaosVarBinary(binaryData), taosTypes.TaosVarBinary([]byte{0x04, 0x05, 0x06})} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.AddVarBinary([]byte{0x07, 0x08, 0x09}) // Attempt to add at index 2 with size 2 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_AddNchar(t *testing.T) { + param := NewParam(2) // Initialize with size 2 + + // Add an nchar value + param.AddNchar("hello") + + expected := []driver.Value{taosTypes.TaosNchar("hello"), nil} + assert.Equal(t, expected, param.GetValues()) + + // Add another nchar value + param.AddNchar("world") + + expected = []driver.Value{taosTypes.TaosNchar("hello"), taosTypes.TaosNchar("world")} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.AddNchar("test") // Attempt to add at index 2 with size 2 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_AddTimestamp(t *testing.T) { + timestamp := time.Date(2022, time.January, 1, 12, 0, 0, 0, time.UTC) + param := NewParam(2) // Initialize with size 2 + + // Add a timestamp value + param.AddTimestamp(timestamp, 6) + + expected := []driver.Value{taosTypes.TaosTimestamp{T: timestamp, Precision: 6}, nil} + assert.Equal(t, expected, param.GetValues()) + + // Add another timestamp value + param.AddTimestamp(timestamp.Add(time.Hour), 9) + + expected = []driver.Value{ + taosTypes.TaosTimestamp{T: timestamp, Precision: 6}, + taosTypes.TaosTimestamp{T: timestamp.Add(time.Hour), Precision: 9}, + } + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.AddTimestamp(timestamp.Add(2*time.Hour), 6) // Attempt to add at index 2 with size 2 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_AddJson(t *testing.T) { + jsonData := []byte(`{"key": "value"}`) + param := NewParam(2) // Initialize with size 2 + + // Add a JSON value + param.AddJson(jsonData) + + expected := []driver.Value{taosTypes.TaosJson(jsonData), nil} + assert.Equal(t, expected, param.GetValues()) + + // Add another JSON value + param.AddJson([]byte(`{"key2": "value2"}`)) + + expected = []driver.Value{ + taosTypes.TaosJson(jsonData), + taosTypes.TaosJson([]byte(`{"key2": "value2"}`)), + } + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.AddJson([]byte(`{"key3": "value3"}`)) // Attempt to add at index 2 with size 2 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_AddGeometry(t *testing.T) { + geometryData := []byte{0x01, 0x02, 0x03} + param := NewParam(2) // Initialize with size 2 + + // Add a geometry value + param.AddGeometry(geometryData) + + expected := []driver.Value{taosTypes.TaosGeometry(geometryData), nil} + assert.Equal(t, expected, param.GetValues()) + + // Add another geometry value + param.AddGeometry([]byte{0x04, 0x05, 0x06}) + + expected = []driver.Value{ + taosTypes.TaosGeometry(geometryData), + taosTypes.TaosGeometry([]byte{0x04, 0x05, 0x06}), + } + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.AddGeometry([]byte{0x07, 0x08, 0x09}) // Attempt to add at index 2 with size 2 + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestParam_AddValue(t *testing.T) { + param := NewParam(2) // Initialize with size 2 + + // Add a binary value + binaryData := []byte{0x01, 0x02, 0x03} + param.AddValue(taosTypes.TaosBinary(binaryData)) + + expected := []driver.Value{taosTypes.TaosBinary(binaryData), nil} + assert.Equal(t, expected, param.GetValues()) + + // Add a varchar value + param.AddValue(taosTypes.TaosVarBinary("hello")) + + expected = []driver.Value{taosTypes.TaosBinary(binaryData), taosTypes.TaosVarBinary("hello")} + assert.Equal(t, expected, param.GetValues()) + + // Test when offset is out of range + param.AddValue(taosTypes.TaosVarBinary("world")) + assert.Equal(t, expected, param.GetValues()) // Should not modify values +} + +func TestNewParamsWithRowValue(t *testing.T) { + rowValues := []driver.Value{taosTypes.TaosBool(true), taosTypes.TaosInt(42), taosTypes.TaosNchar("hello")} + + params := NewParamsWithRowValue(rowValues) + + expected := []*Param{ + { + size: 1, + value: []driver.Value{taosTypes.TaosBool(true)}, + offset: 1, + }, + { + size: 1, + value: []driver.Value{taosTypes.TaosInt(42)}, + offset: 1, + }, + { + size: 1, + value: []driver.Value{taosTypes.TaosNchar("hello")}, + offset: 1, + }, + } + + for i, param := range params { + assert.Equal(t, expected[i].size, param.size) + assert.Equal(t, expected[i].value, param.value) + assert.Equal(t, expected[i].offset, param.offset) + } +} diff --git a/common/parser/block.go b/common/parser/block.go index 7228e90..dd7dfa1 100644 --- a/common/parser/block.go +++ b/common/parser/block.go @@ -3,6 +3,7 @@ package parser import ( "database/sql/driver" "math" + "unicode/utf8" "unsafe" "github.com/taosdata/driver-go/v3/common" @@ -108,28 +109,9 @@ type rawConvertFunc func(pStart unsafe.Pointer, row int, arg ...interface{}) dri type rawConvertVarDataFunc func(pHeader, pStart unsafe.Pointer, row int) driver.Value -var rawConvertFuncMap = map[uint8]rawConvertFunc{ - uint8(common.TSDB_DATA_TYPE_BOOL): rawConvertBool, - uint8(common.TSDB_DATA_TYPE_TINYINT): rawConvertTinyint, - uint8(common.TSDB_DATA_TYPE_SMALLINT): rawConvertSmallint, - uint8(common.TSDB_DATA_TYPE_INT): rawConvertInt, - uint8(common.TSDB_DATA_TYPE_BIGINT): rawConvertBigint, - uint8(common.TSDB_DATA_TYPE_UTINYINT): rawConvertUTinyint, - uint8(common.TSDB_DATA_TYPE_USMALLINT): rawConvertUSmallint, - uint8(common.TSDB_DATA_TYPE_UINT): rawConvertUInt, - uint8(common.TSDB_DATA_TYPE_UBIGINT): rawConvertUBigint, - uint8(common.TSDB_DATA_TYPE_FLOAT): rawConvertFloat, - uint8(common.TSDB_DATA_TYPE_DOUBLE): rawConvertDouble, - uint8(common.TSDB_DATA_TYPE_TIMESTAMP): rawConvertTime, -} - -var rawConvertVarDataMap = map[uint8]rawConvertVarDataFunc{ - uint8(common.TSDB_DATA_TYPE_BINARY): rawConvertBinary, - uint8(common.TSDB_DATA_TYPE_NCHAR): rawConvertNchar, - uint8(common.TSDB_DATA_TYPE_JSON): rawConvertJson, - uint8(common.TSDB_DATA_TYPE_VARBINARY): rawConvertVarBinary, - uint8(common.TSDB_DATA_TYPE_GEOMETRY): rawConvertGeometry, -} +var rawConvertFuncSlice = [15]rawConvertFunc{} + +var rawConvertVarDataSlice = [21]rawConvertVarDataFunc{} func ItemIsNull(pHeader unsafe.Pointer, row int) bool { offset := CharOffset(row) @@ -196,20 +178,27 @@ func rawConvertTime(pStart unsafe.Pointer, row int, arg ...interface{}) driver.V } func rawConvertVarBinary(pHeader, pStart unsafe.Pointer, row int) driver.Value { + result := rawGetBytes(pHeader, pStart, row) + if result == nil { + return nil + } + return result +} + +func rawGetBytes(pHeader, pStart unsafe.Pointer, row int) []byte { offset := *((*int32)(pointer.AddUintptr(pHeader, uintptr(row*4)))) if offset == -1 { return nil } currentRow := pointer.AddUintptr(pStart, uintptr(offset)) clen := *((*uint16)(currentRow)) - currentRow = unsafe.Pointer(uintptr(currentRow) + 2) - - binaryVal := make([]byte, clen) - - for index := uint16(0); index < clen; index++ { - binaryVal[index] = *((*byte)(unsafe.Pointer(uintptr(currentRow) + uintptr(index)))) + if clen == 0 { + return make([]byte, 0) } - return binaryVal[:] + currentRow = pointer.AddUintptr(currentRow, 2) + result := make([]byte, clen) + Copy(currentRow, result, 0, int(clen)) + return result } func rawConvertGeometry(pHeader, pStart unsafe.Pointer, row int) driver.Value { @@ -217,20 +206,11 @@ func rawConvertGeometry(pHeader, pStart unsafe.Pointer, row int) driver.Value { } func rawConvertBinary(pHeader, pStart unsafe.Pointer, row int) driver.Value { - offset := *((*int32)(pointer.AddUintptr(pHeader, uintptr(row*4)))) - if offset == -1 { + result := rawGetBytes(pHeader, pStart, row) + if result == nil { return nil } - currentRow := pointer.AddUintptr(pStart, uintptr(offset)) - clen := *((*uint16)(currentRow)) - currentRow = unsafe.Pointer(uintptr(currentRow) + 2) - - binaryVal := make([]byte, clen) - - for index := uint16(0); index < clen; index++ { - binaryVal[index] = *((*byte)(unsafe.Pointer(uintptr(currentRow) + uintptr(index)))) - } - return string(binaryVal[:]) + return *(*string)(unsafe.Pointer(&result)) } func rawConvertNchar(pHeader, pStart unsafe.Pointer, row int) driver.Value { @@ -240,31 +220,34 @@ func rawConvertNchar(pHeader, pStart unsafe.Pointer, row int) driver.Value { } currentRow := pointer.AddUintptr(pStart, uintptr(offset)) clen := *((*uint16)(currentRow)) / 4 + if clen == 0 { + return "" + } currentRow = unsafe.Pointer(uintptr(currentRow) + 2) - - binaryVal := make([]rune, clen) - - for index := uint16(0); index < clen; index++ { - binaryVal[index] = *((*rune)(unsafe.Pointer(uintptr(currentRow) + uintptr(index*4)))) + utf8Bytes := make([]byte, clen*utf8.UTFMax) + index := 0 + utf32Slice := (*[1 << 30]rune)(currentRow)[:clen:clen] + for _, runeValue := range utf32Slice { + index += utf8.EncodeRune(utf8Bytes[index:], runeValue) } - return string(binaryVal) + utf8Bytes = utf8Bytes[:index] + return *(*string)(unsafe.Pointer(&utf8Bytes)) } func rawConvertJson(pHeader, pStart unsafe.Pointer, row int) driver.Value { - offset := *((*int32)(pointer.AddUintptr(pHeader, uintptr(row*4)))) - if offset == -1 { - return nil - } - currentRow := pointer.AddUintptr(pStart, uintptr(offset)) - clen := *((*uint16)(currentRow)) - currentRow = pointer.AddUintptr(currentRow, 2) - - binaryVal := make([]byte, clen) + return rawConvertVarBinary(pHeader, pStart, row) +} - for index := uint16(0); index < clen; index++ { - binaryVal[index] = *((*byte)(pointer.AddUintptr(currentRow, uintptr(index)))) +func ReadBlockSimple(block unsafe.Pointer, precision int) [][]driver.Value { + blockSize := RawBlockGetNumOfRows(block) + colCount := RawBlockGetNumOfCols(block) + colInfo := make([]RawBlockColInfo, colCount) + RawBlockGetColInfo(block, colInfo) + colTypes := make([]uint8, colCount) + for i := int32(0); i < colCount; i++ { + colTypes[i] = uint8(colInfo[i].ColType) } - return binaryVal[:] + return ReadBlock(block, int(blockSize), colTypes, precision) } // ReadBlock in-place @@ -278,7 +261,7 @@ func ReadBlock(block unsafe.Pointer, blockSize int, colTypes []uint8, precision for column := 0; column < colCount; column++ { colLength := *((*int32)(pointer.AddUintptr(block, lengthOffset+uintptr(column)*Int32Size))) if IsVarDataType(colTypes[column]) { - convertF := rawConvertVarDataMap[colTypes[column]] + convertF := rawConvertVarDataSlice[colTypes[column]] pStart = pointer.AddUintptr(pHeader, Int32Size*uintptr(blockSize)) for row := 0; row < blockSize; row++ { if column == 0 { @@ -287,7 +270,7 @@ func ReadBlock(block unsafe.Pointer, blockSize int, colTypes []uint8, precision r[row][column] = convertF(pHeader, pStart, row) } } else { - convertF := rawConvertFuncMap[colTypes[column]] + convertF := rawConvertFuncSlice[colTypes[column]] pStart = pointer.AddUintptr(pHeader, nullBitMapOffset) for row := 0; row < blockSize; row++ { if column == 0 { @@ -314,11 +297,11 @@ func ReadRow(dest []driver.Value, block unsafe.Pointer, blockSize int, row int, for column := 0; column < colCount; column++ { colLength := *((*int32)(pointer.AddUintptr(block, lengthOffset+uintptr(column)*Int32Size))) if IsVarDataType(colTypes[column]) { - convertF := rawConvertVarDataMap[colTypes[column]] + convertF := rawConvertVarDataSlice[colTypes[column]] pStart = pointer.AddUintptr(pHeader, Int32Size*uintptr(blockSize)) dest[column] = convertF(pHeader, pStart, row) } else { - convertF := rawConvertFuncMap[colTypes[column]] + convertF := rawConvertFuncSlice[colTypes[column]] pStart = pointer.AddUintptr(pHeader, nullBitMapOffset) if ItemIsNull(pHeader, row) { dest[column] = nil @@ -340,7 +323,7 @@ func ReadBlockWithTimeFormat(block unsafe.Pointer, blockSize int, colTypes []uin for column := 0; column < colCount; column++ { colLength := *((*int32)(pointer.AddUintptr(block, lengthOffset+uintptr(column)*Int32Size))) if IsVarDataType(colTypes[column]) { - convertF := rawConvertVarDataMap[colTypes[column]] + convertF := rawConvertVarDataSlice[colTypes[column]] pStart = pointer.AddUintptr(pHeader, uintptr(4*blockSize)) for row := 0; row < blockSize; row++ { if column == 0 { @@ -349,7 +332,7 @@ func ReadBlockWithTimeFormat(block unsafe.Pointer, blockSize int, colTypes []uin r[row][column] = convertF(pHeader, pStart, row) } } else { - convertF := rawConvertFuncMap[colTypes[column]] + convertF := rawConvertFuncSlice[colTypes[column]] pStart = pointer.AddUintptr(pHeader, nullBitMapOffset) for row := 0; row < blockSize; row++ { if column == 0 { @@ -369,12 +352,33 @@ func ReadBlockWithTimeFormat(block unsafe.Pointer, blockSize int, colTypes []uin func ItemRawBlock(colType uint8, pHeader, pStart unsafe.Pointer, row int, precision int, timeFormat FormatTimeFunc) driver.Value { if IsVarDataType(colType) { - return rawConvertVarDataMap[colType](pHeader, pStart, row) + return rawConvertVarDataSlice[colType](pHeader, pStart, row) } else { if ItemIsNull(pHeader, row) { return nil } else { - return rawConvertFuncMap[colType](pStart, row, precision, timeFormat) + return rawConvertFuncSlice[colType](pStart, row, precision, timeFormat) } } } + +func init() { + rawConvertFuncSlice[uint8(common.TSDB_DATA_TYPE_BOOL)] = rawConvertBool + rawConvertFuncSlice[uint8(common.TSDB_DATA_TYPE_TINYINT)] = rawConvertTinyint + rawConvertFuncSlice[uint8(common.TSDB_DATA_TYPE_SMALLINT)] = rawConvertSmallint + rawConvertFuncSlice[uint8(common.TSDB_DATA_TYPE_INT)] = rawConvertInt + rawConvertFuncSlice[uint8(common.TSDB_DATA_TYPE_BIGINT)] = rawConvertBigint + rawConvertFuncSlice[uint8(common.TSDB_DATA_TYPE_UTINYINT)] = rawConvertUTinyint + rawConvertFuncSlice[uint8(common.TSDB_DATA_TYPE_USMALLINT)] = rawConvertUSmallint + rawConvertFuncSlice[uint8(common.TSDB_DATA_TYPE_UINT)] = rawConvertUInt + rawConvertFuncSlice[uint8(common.TSDB_DATA_TYPE_UBIGINT)] = rawConvertUBigint + rawConvertFuncSlice[uint8(common.TSDB_DATA_TYPE_FLOAT)] = rawConvertFloat + rawConvertFuncSlice[uint8(common.TSDB_DATA_TYPE_DOUBLE)] = rawConvertDouble + rawConvertFuncSlice[uint8(common.TSDB_DATA_TYPE_TIMESTAMP)] = rawConvertTime + + rawConvertVarDataSlice[uint8(common.TSDB_DATA_TYPE_BINARY)] = rawConvertBinary + rawConvertVarDataSlice[uint8(common.TSDB_DATA_TYPE_NCHAR)] = rawConvertNchar + rawConvertVarDataSlice[uint8(common.TSDB_DATA_TYPE_JSON)] = rawConvertJson + rawConvertVarDataSlice[uint8(common.TSDB_DATA_TYPE_VARBINARY)] = rawConvertVarBinary + rawConvertVarDataSlice[uint8(common.TSDB_DATA_TYPE_GEOMETRY)] = rawConvertGeometry +} diff --git a/common/parser/block_test.go b/common/parser/block_test.go index 5b7232a..42b2d26 100644 --- a/common/parser/block_test.go +++ b/common/parser/block_test.go @@ -663,12 +663,6 @@ func TestParseBlock(t *testing.T) { t.Error(errors.NewError(code, errStr)) return } - fileCount := wrapper.TaosNumFields(res) - rh, err := wrapper.ReadColumn(res, fileCount) - if err != nil { - t.Error(err) - return - } precision := wrapper.TaosResultPrecision(res) var data [][]driver.Value for { @@ -684,7 +678,7 @@ func TestParseBlock(t *testing.T) { break } version := RawBlockGetVersion(block) - assert.Equal(t, int32(1), version) + t.Log(version) length := RawBlockGetLength(block) assert.Equal(t, int32(447), length) rows := RawBlockGetNumOfRows(block) @@ -771,7 +765,7 @@ func TestParseBlock(t *testing.T) { }, infos, ) - d := ReadBlock(block, blockSize, rh.ColTypes, precision) + d := ReadBlockSimple(block, precision) data = append(data, d...) } wrapper.TaosFreeResult(res) diff --git a/common/parser/mem.go b/common/parser/mem.go new file mode 100644 index 0000000..f0d4b00 --- /dev/null +++ b/common/parser/mem.go @@ -0,0 +1,12 @@ +package parser + +import "unsafe" + +//go:noescape +func memmove(to, from unsafe.Pointer, n uintptr) + +//go:linkname memmove runtime.memmove + +func Copy(source unsafe.Pointer, data []byte, index int, length int) { + memmove(unsafe.Pointer(&data[index]), source, uintptr(length)) +} diff --git a/common/parser/mem.s b/common/parser/mem.s new file mode 100644 index 0000000..e69de29 diff --git a/common/parser/mem_test.go b/common/parser/mem_test.go new file mode 100644 index 0000000..d3e244b --- /dev/null +++ b/common/parser/mem_test.go @@ -0,0 +1,20 @@ +package parser + +import ( + "testing" + "unsafe" + + "github.com/stretchr/testify/assert" +) + +func TestCopy(t *testing.T) { + data := []byte("World") + data1 := make([]byte, 10) + data1[0] = 'H' + data1[1] = 'e' + data1[2] = 'l' + data1[3] = 'l' + data1[4] = 'o' + Copy(unsafe.Pointer(&data[0]), data1, 5, 5) + assert.Equal(t, "HelloWorld", string(data1)) +} diff --git a/common/parser/raw.go b/common/parser/raw.go new file mode 100644 index 0000000..61ad125 --- /dev/null +++ b/common/parser/raw.go @@ -0,0 +1,185 @@ +package parser + +import ( + "fmt" + "unsafe" + + "github.com/taosdata/driver-go/v3/common/pointer" +) + +type TMQRawDataParser struct { + block unsafe.Pointer + offset uintptr +} + +func NewTMQRawDataParser() *TMQRawDataParser { + return &TMQRawDataParser{} +} + +type TMQBlockInfo struct { + RawBlock unsafe.Pointer + Precision int + Schema []*TMQRawDataSchema + TableName string +} + +type TMQRawDataSchema struct { + ColType uint8 + Flag int8 + Bytes int64 + ColID int + Name string +} + +func (p *TMQRawDataParser) getTypeSkip(t int8) (int, error) { + skip := 8 + switch t { + case 1: + case 2, 3: + skip = 16 + default: + return 0, fmt.Errorf("unknown type %d", t) + } + return skip, nil +} + +func (p *TMQRawDataParser) skipHead() error { + v := p.parseInt8() + if v >= 100 { + skip := p.parseInt32() + p.skip(int(skip)) + return nil + } else { + skip, err := p.getTypeSkip(v) + if err != nil { + return err + } + p.skip(skip) + v = p.parseInt8() + skip, err = p.getTypeSkip(v) + if err != nil { + return err + } + p.skip(skip) + return nil + } +} + +func (p *TMQRawDataParser) skip(count int) { + p.offset += uintptr(count) +} + +func (p *TMQRawDataParser) parseBlockInfos() []*TMQBlockInfo { + blockNum := p.parseInt32() + blockInfos := make([]*TMQBlockInfo, blockNum) + withTableName := p.parseBool() + withSchema := p.parseBool() + for i := int32(0); i < blockNum; i++ { + blockInfo := &TMQBlockInfo{} + blockTotalLen := p.parseVariableByteInteger() + p.skip(17) + blockInfo.Precision = int(p.parseUint8()) + blockInfo.RawBlock = pointer.AddUintptr(p.block, p.offset) + p.skip(blockTotalLen - 18) + if withSchema { + cols := p.parseZigzagVariableByteInteger() + //version + _ = p.parseZigzagVariableByteInteger() + + blockInfo.Schema = make([]*TMQRawDataSchema, cols) + for j := 0; j < cols; j++ { + blockInfo.Schema[j] = p.parseSchema() + } + } + if withTableName { + blockInfo.TableName = p.parseName() + } + blockInfos[i] = blockInfo + } + return blockInfos +} + +func (p *TMQRawDataParser) parseZigzagVariableByteInteger() int { + return zigzagDecode(p.parseVariableByteInteger()) +} + +func (p *TMQRawDataParser) parseBool() bool { + v := *(*int8)(pointer.AddUintptr(p.block, p.offset)) + p.skip(1) + return v != 0 +} + +func (p *TMQRawDataParser) parseUint8() uint8 { + v := *(*uint8)(pointer.AddUintptr(p.block, p.offset)) + p.skip(1) + return v +} + +func (p *TMQRawDataParser) parseInt8() int8 { + v := *(*int8)(pointer.AddUintptr(p.block, p.offset)) + p.skip(1) + return v +} + +func (p *TMQRawDataParser) parseInt32() int32 { + v := *(*int32)(pointer.AddUintptr(p.block, p.offset)) + p.skip(4) + return v +} + +func (p *TMQRawDataParser) parseSchema() *TMQRawDataSchema { + colType := p.parseUint8() + flag := p.parseInt8() + bytes := int64(p.parseZigzagVariableByteInteger()) + colID := p.parseZigzagVariableByteInteger() + name := p.parseName() + return &TMQRawDataSchema{ + ColType: colType, + Flag: flag, + Bytes: bytes, + ColID: colID, + Name: name, + } +} + +func (p *TMQRawDataParser) parseName() string { + nameLen := p.parseVariableByteInteger() + name := make([]byte, nameLen-1) + for i := 0; i < nameLen-1; i++ { + name[i] = *(*byte)(pointer.AddUintptr(p.block, p.offset+uintptr(i))) + } + p.skip(nameLen) + return string(name) +} + +func (p *TMQRawDataParser) Parse(block unsafe.Pointer) ([]*TMQBlockInfo, error) { + p.reset(block) + err := p.skipHead() + if err != nil { + return nil, err + } + return p.parseBlockInfos(), nil +} + +func (p *TMQRawDataParser) reset(block unsafe.Pointer) { + p.block = block + p.offset = 0 +} + +func (p *TMQRawDataParser) parseVariableByteInteger() int { + multiplier := 1 + value := 0 + for { + encodedByte := p.parseUint8() + value += int(encodedByte&127) * multiplier + if encodedByte&128 == 0 { + break + } + multiplier *= 128 + } + return value +} + +func zigzagDecode(n int) int { + return (n >> 1) ^ (-(n & 1)) +} diff --git a/common/parser/raw_test.go b/common/parser/raw_test.go new file mode 100644 index 0000000..521a626 --- /dev/null +++ b/common/parser/raw_test.go @@ -0,0 +1,1049 @@ +package parser + +import ( + "database/sql/driver" + "fmt" + "testing" + "time" + "unsafe" + + "github.com/stretchr/testify/assert" +) + +func TestParse(t *testing.T) { + data := []byte{ + 0x01, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x01, + 0x05, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x01, 0x00, 0x00, 0x00, + + 0x01, + 0x01, + + 0xc5, 0x01, + + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + + 0x02, + + 0x02, 0x00, 0x00, 0x00, + 0xb3, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + 0x06, 0x04, 0x00, 0x00, 0x00, + 0x08, 0x82, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x5c, 0x00, 0x00, 0x00, + + 0x00, + 0xc0, 0xed, 0x82, 0x05, 0xc3, 0x1b, 0xab, 0x17, + + 0x80, + 0x00, 0x00, 0x00, 0x00, + + 0x80, + 0x00, 0x00, 0x00, 0x00, + + 0x00, 0x00, 0x00, 0x00, + 0x5a, 0x00, + 0x61, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, + 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, + 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, + 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, + 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, + 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, + 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, + 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, + 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, + 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, + 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, + 0x34, 0x61, + + 0x08, + 0x00, + + 0x09, + 0x01, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x01, + 0x08, + 0x04, + 0x03, + 0x63, 0x31, 0x00, + + 0x06, + 0x01, + 0x08, + 0x06, + 0x03, + 0x63, 0x32, 0x00, + + 0x08, + 0x01, + 0x84, 0x02, + 0x08, + 0x03, 0x63, 0x33, 0x00, + + 0x05, + 0x63, 0x74, 0x62, 0x30, 0x00, + + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + } + parser := NewTMQRawDataParser() + blockInfos, err := parser.Parse(unsafe.Pointer(&data[0])) + assert.NoError(t, err) + assert.Equal(t, 1, len(blockInfos)) + assert.Equal(t, 2, blockInfos[0].Precision) + assert.Equal(t, 4, len(blockInfos[0].Schema)) + assert.Equal(t, []*TMQRawDataSchema{ + { + ColType: 9, + Flag: 1, + Bytes: 8, + ColID: 1, + Name: "ts", + }, + { + ColType: 4, + Flag: 1, + Bytes: 4, + ColID: 2, + Name: "c1", + }, + { + ColType: 6, + Flag: 1, + Bytes: 4, + ColID: 3, + Name: "c2", + }, + { + ColType: 8, + Flag: 1, + Bytes: 130, + ColID: 4, + Name: "c3", + }, + }, blockInfos[0].Schema) + assert.Equal(t, "ctb0", blockInfos[0].TableName) +} + +func TestParseTwoBlock(t *testing.T) { + data := []byte{ + 0x01, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x01, + 0x07, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x02, 0x00, 0x00, 0x00, + + 0x00, // withTbName false + 0x01, // withSchema true + + 0x60, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + 0x00, + + 0x02, 0x00, 0x00, 0x00, + 0x4e, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + 0x08, 0x0c, 0x00, 0x00, 0x00, + + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x00, 0x00, + + 0x00, + 0xf8, 0x6b, 0x75, 0x35, 0x8d, 0x01, 0x00, 0x00, + + 0x00, + 0x00, 0x00, 0x00, 0x00, + + 0x00, 0x00, 0x00, 0x00, + 0x03, 0x00, + 0x63, 0x74, 0x30, + + 0x06, + 0x00, + + 0x09, + 0x00, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x00, + 0x08, + 0x04, + 0x02, + 0x76, 0x00, + + 0x08, + 0x00, + 0x18, + 0x06, + 0x02, + 0x6e, 0x00, + + 0x60, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + + 0x00, + + 0x02, 0x00, 0x00, 0x00, + 0x4e, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + 0x08, 0x0c, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x00, 0x00, + + 0x00, + 0xf9, 0x6b, 0x75, 0x35, + 0x8d, 0x01, 0x00, 0x00, + + 0x00, + 0x01, 0x00, 0x00, 0x00, + + 0x00, 0x00, 0x00, 0x00, + 0x03, 0x00, + 0x63, 0x74, 0x31, + + 0x06, + 0x00, + + 0x09, + 0x00, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x00, + 0x08, + 0x04, + 0x02, + 0x76, 0x00, + + 0x08, + 0x00, + 0x18, + 0x06, + 0x02, + 0x6e, 0x00, + + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + } + parser := NewTMQRawDataParser() + blockInfos, err := parser.Parse(unsafe.Pointer(&data[0])) + assert.NoError(t, err) + assert.Equal(t, 2, len(blockInfos)) + assert.Equal(t, 0, blockInfos[0].Precision) + assert.Equal(t, 0, blockInfos[1].Precision) + assert.Equal(t, 3, len(blockInfos[0].Schema)) + assert.Equal(t, []*TMQRawDataSchema{ + { + ColType: 9, + Flag: 0, + Bytes: 8, + ColID: 1, + Name: "ts", + }, + { + ColType: 4, + Flag: 0, + Bytes: 4, + ColID: 2, + Name: "v", + }, + { + ColType: 8, + Flag: 0, + Bytes: 12, + ColID: 3, + Name: "n", + }, + }, blockInfos[0].Schema) + assert.Equal(t, []*TMQRawDataSchema{ + { + ColType: 9, + Flag: 0, + Bytes: 8, + ColID: 1, + Name: "ts", + }, + { + ColType: 4, + Flag: 0, + Bytes: 4, + ColID: 2, + Name: "v", + }, + { + ColType: 8, + Flag: 0, + Bytes: 12, + ColID: 3, + Name: "n", + }, + }, blockInfos[1].Schema) + assert.Equal(t, "", blockInfos[0].TableName) + assert.Equal(t, "", blockInfos[1].TableName) +} + +func TestParseTenBlock(t *testing.T) { + data := []byte{ + 0x01, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x01, + 0x0d, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x0a, 0x00, 0x00, 0x00, + 0x01, + 0x01, + + // block1 + 0x4e, + + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + 0x00, + + 0x02, 0x00, 0x00, 0x00, + 0x3c, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + + 0x00, + 0x52, 0xed, 0x5b, 0x3a, 0x8d, 0x01, 0x00, 0x00, + + 0x00, + 0x01, 0x00, 0x00, 0x00, + + 0x04, + 0x00, + + 0x09, + 0x01, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x01, + 0x08, + 0x04, + 0x02, + 0x76, 0x00, + + 0x03, + 0x74, 0x31, 0x00, + + //block2 + 0x4e, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + + 0x00, + + 0x02, 0x00, 0x00, 0x00, + 0x3c, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x00, + 0x52, 0xed, 0x5b, 0x3a, 0x8d, 0x01, 0x00, 0x00, + 0x00, + 0x02, 0x00, 0x00, 0x00, + + 0x04, + 0x00, + + 0x09, + 0x01, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x01, + 0x08, + 0x04, + 0x02, + 0x76, 0x00, + + 0x03, + 0x74, 0x32, 0x00, + + //block3 + 0x4e, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + + 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x3c, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + + 0x00, + 0x52, 0xed, 0x5b, 0x3a, 0x8d, 0x01, 0x00, 0x00, + + 0x00, + 0x03, 0x00, 0x00, 0x00, + + 0x04, + 0x00, + + 0x09, + 0x01, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x01, + 0x08, + 0x04, + 0x02, + 0x76, 0x00, + + 0x03, + 0x74, 0x33, 0x00, + + //block4 + 0x4e, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + 0x00, + + 0x02, 0x00, 0x00, 0x00, + 0x3c, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + + 0x00, + 0x52, 0xed, 0x5b, 0x3a, 0x8d, 0x01, 0x00, 0x00, + + 0x00, + 0x04, 0x00, 0x00, 0x00, + + 0x04, + 0x00, + + 0x09, + 0x01, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x01, + 0x08, + 0x04, + 0x02, + 0x76, 0x00, + + 0x03, + 0x74, 0x34, 0x00, + + // block5 + 0x4e, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + 0x00, + + 0x02, 0x00, 0x00, 0x00, + 0x3c, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + + 0x00, + 0x52, 0xed, 0x5b, 0x3a, 0x8d, 0x01, 0x00, 0x00, + + 0x00, + 0x05, 0x00, 0x00, 0x00, + + 0x04, + 0x00, + + 0x09, + 0x01, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x01, + 0x08, + 0x04, + 0x02, + 0x76, 0x00, + + 0x03, + 0x74, 0x35, 0x00, + + //block6 + 0x4e, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + 0x00, + + 0x02, 0x00, 0x00, 0x00, + 0x3c, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + + 0x00, + 0x52, 0xed, 0x5b, 0x3a, 0x8d, 0x01, 0x00, 0x00, + + 0x00, + 0x06, 0x00, 0x00, 0x00, + + 0x04, + 0x00, + + 0x09, + 0x01, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x01, + 0x08, + 0x04, + 0x02, + 0x76, 0x00, + + 0x03, + 0x74, 0x36, 0x00, + + //block7 + 0x4e, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + 0x00, + + 0x02, 0x00, 0x00, 0x00, + 0x3c, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + + 0x00, + 0x52, 0xed, 0x5b, 0x3a, 0x8d, 0x01, 0x00, 0x00, + + 0x00, + 0x07, 0x00, 0x00, 0x00, + + 0x04, + 0x00, + + 0x09, + 0x01, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x01, + 0x08, + 0x04, + 0x02, + 0x76, 0x00, + + 0x03, + 0x74, 0x37, 0x00, + + //block8 + 0x4e, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + 0x00, + + 0x02, 0x00, 0x00, 0x00, + 0x3c, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + + 0x00, + 0x52, 0xed, 0x5b, 0x3a, 0x8d, 0x01, 0x00, 0x00, + + 0x00, + 0x08, 0x00, 0x00, 0x00, + + 0x04, + 0x00, + + 0x09, + 0x01, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x01, + 0x08, + 0x04, + 0x02, + 0x76, 0x00, + + 0x03, + 0x74, 0x38, 0x00, + + //block9 + 0x4e, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + 0x00, + + 0x02, 0x00, 0x00, 0x00, + 0x3c, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + + 0x00, + 0x52, 0xed, 0x5b, 0x3a, 0x8d, 0x01, 0x00, 0x00, + + 0x00, + 0x09, 0x00, 0x00, 0x00, + + 0x04, + 0x00, + + 0x09, + 0x01, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x01, + 0x08, + 0x04, + 0x02, + 0x76, 0x00, + + 0x03, + 0x74, 0x39, 0x00, + + //block10 + 0x4e, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + 0x00, + + 0x02, 0x00, 0x00, 0x00, + 0x3c, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + + 0x00, + 0x52, 0xed, 0x5b, 0x3a, 0x8d, 0x01, 0x00, 0x00, + 0x00, + 0x0a, 0x00, 0x00, 0x00, + + 0x04, + 0x00, + + 0x09, + 0x01, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x01, + 0x08, + 0x04, + 0x02, + 0x76, 0x00, + 0x04, + 0x74, 0x31, 0x30, 0x00, + + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + } + parser := NewTMQRawDataParser() + blockInfos, err := parser.Parse(unsafe.Pointer(&data[0])) + assert.NoError(t, err) + assert.Equal(t, 10, len(blockInfos)) + for i := 0; i < 10; i++ { + assert.Equal(t, 0, blockInfos[i].Precision) + assert.Equal(t, 2, len(blockInfos[i].Schema)) + assert.Equal(t, []*TMQRawDataSchema{ + { + ColType: 9, + Flag: 1, + Bytes: 8, + ColID: 1, + Name: "ts", + }, + { + ColType: 4, + Flag: 1, + Bytes: 4, + ColID: 2, + Name: "v", + }, + }, blockInfos[i].Schema) + assert.Equal(t, fmt.Sprintf("t%d", i+1), blockInfos[i].TableName) + value := ReadBlockSimple(blockInfos[i].RawBlock, blockInfos[i].Precision) + ts := time.Unix(0, 1706081119570000000).Local() + assert.Equal(t, [][]driver.Value{{ts, int32(i + 1)}}, value) + } +} + +func TestVersion100Block(t *testing.T) { + data := []byte{ + 0x64, //version + 0x12, 0x00, 0x00, 0x00, // skip 18 bytes + 0x11, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x11, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, //block count 1 + + 0x01, // with table name + 0x01, // with schema + + 0x92, 0x02, // block length 274 + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + 0x00, + + 0x02, 0x00, 0x00, 0x00, + 0x00, 0x01, 0x00, 0x00, // 256 + 0x01, 0x00, 0x00, 0x00, // rows + 0x0e, 0x00, 0x00, 0x00, // cols + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x01, 0x01, 0x00, 0x00, 0x00, + 0x02, 0x01, 0x00, 0x00, 0x00, + 0x03, 0x02, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + 0x05, 0x08, 0x00, 0x00, 0x00, + 0x0b, 0x01, 0x00, 0x00, 0x00, + 0x0c, 0x02, 0x00, 0x00, 0x00, + 0x0d, 0x04, 0x00, 0x00, 0x00, + 0x0e, 0x08, 0x00, 0x00, 0x00, + 0x06, 0x04, 0x00, 0x00, 0x00, + 0x07, 0x08, 0x00, 0x00, 0x00, + 0x08, 0x16, 0x00, 0x00, 0x00, + 0x0a, 0x52, 0x00, 0x00, 0x00, + + 0x08, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x16, 0x00, 0x00, 0x00, + + 0x00, + 0x9e, 0x37, 0x6a, 0x04, 0x8f, 0x01, 0x00, 0x00, + + 0x00, + 0x01, + + 0x00, + 0x02, + + 0x00, + 0x03, 0x00, + + 0x00, + 0x04, 0x00, 0x00, 0x00, + + 0x00, + 0x05, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x00, + 0x06, + + 0x00, + 0x07, 0x00, + + 0x00, + 0x08, 0x00, 0x00, 0x00, + + 0x00, + 0x09, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x00, + 0xcf, 0xf7, 0x21, 0x41, + + 0x00, + 0xe5, 0xd0, 0x22, 0xdb, 0xf9, 0x3e, 0x26, 0x40, + + 0x00, 0x00, 0x00, 0x00, + 0x06, 0x00, + 0x62, 0x69, 0x6e, 0x61, 0x72, 0x79, + + 0x00, 0x00, 0x00, 0x00, + 0x14, 0x00, + 0x6e, 0x00, 0x00, 0x00, + 0x63, 0x00, 0x00, 0x00, + 0x68, 0x00, 0x00, 0x00, + 0x61, 0x00, 0x00, 0x00, + 0x72, 0x00, 0x00, 0x00, + + 0x00, // + + 0x1c, // cols 14 + 0x00, // version + + // col meta + 0x09, 0x01, 0x10, 0x02, 0x03, 0x74, 0x73, 0x00, + 0x01, 0x01, 0x02, 0x04, 0x03, 0x63, 0x31, 0x00, + 0x02, 0x01, 0x02, 0x06, 0x03, 0x63, 0x32, 0x00, + 0x03, 0x01, 0x04, 0x08, 0x03, 0x63, 0x33, 0x00, + 0x04, 0x01, 0x08, 0x0a, 0x03, 0x63, 0x34, 0x00, + 0x05, 0x01, 0x10, 0x0c, 0x03, 0x63, 0x35, 0x00, + 0x0b, 0x01, 0x02, 0x0e, 0x03, 0x63, 0x36, 0x00, + 0x0c, 0x01, 0x04, 0x10, 0x03, 0x63, 0x37, 0x00, + 0x0d, 0x01, 0x08, 0x12, 0x03, 0x63, 0x38, 0x00, + 0x0e, 0x01, 0x10, 0x14, 0x03, 0x63, 0x39, 0x00, + 0x06, 0x01, 0x08, 0x16, 0x04, 0x63, 0x31, 0x30, 0x00, + 0x07, 0x01, 0x10, 0x18, 0x04, 0x63, 0x31, 0x31, 0x00, + 0x08, 0x01, 0x2c, 0x1a, 0x04, 0x63, 0x31, 0x32, 0x00, + 0x0a, 0x01, 0xa4, 0x01, 0x1c, 0x04, 0x63, 0x31, 0x33, 0x00, + + 0x06, // table name + 0x74, 0x5f, 0x61, 0x6c, 0x6c, 0x00, + // sleep time + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + } + parser := NewTMQRawDataParser() + blockInfos, err := parser.Parse(unsafe.Pointer(&data[0])) + assert.NoError(t, err) + assert.Equal(t, 1, len(blockInfos)) + assert.Equal(t, 0, blockInfos[0].Precision) + assert.Equal(t, 14, len(blockInfos[0].Schema)) + assert.Equal(t, []*TMQRawDataSchema{ + { + ColType: 9, + Flag: 1, + Bytes: 8, + ColID: 1, + Name: "ts", + }, + { + ColType: 1, + Flag: 1, + Bytes: 1, + ColID: 2, + Name: "c1", + }, + { + ColType: 2, + Flag: 1, + Bytes: 1, + ColID: 3, + Name: "c2", + }, + { + ColType: 3, + Flag: 1, + Bytes: 2, + ColID: 4, + Name: "c3", + }, + { + ColType: 4, + Flag: 1, + Bytes: 4, + ColID: 5, + Name: "c4", + }, + { + ColType: 5, + Flag: 1, + Bytes: 8, + ColID: 6, + Name: "c5", + }, + { + ColType: 11, + Flag: 1, + Bytes: 1, + ColID: 7, + Name: "c6", + }, + { + ColType: 12, + Flag: 1, + Bytes: 2, + ColID: 8, + Name: "c7", + }, + { + ColType: 13, + Flag: 1, + Bytes: 4, + ColID: 9, + Name: "c8", + }, + { + ColType: 14, + Flag: 1, + Bytes: 8, + ColID: 10, + Name: "c9", + }, + { + ColType: 6, + Flag: 1, + Bytes: 4, + ColID: 11, + Name: "c10", + }, + { + ColType: 7, + Flag: 1, + Bytes: 8, + ColID: 12, + Name: "c11", + }, + { + ColType: 8, + Flag: 1, + Bytes: 22, + ColID: 13, + Name: "c12", + }, + { + ColType: 10, + Flag: 1, + Bytes: 82, + ColID: 14, + Name: "c13", + }, + }, blockInfos[0].Schema) + assert.Equal(t, "t_all", blockInfos[0].TableName) + value := ReadBlockSimple(blockInfos[0].RawBlock, blockInfos[0].Precision) + expect := []driver.Value{ + time.Unix(0, 1713766021022000000).Local(), + true, + int8(2), + int16(3), + int32(4), + int64(5), + uint8(6), + uint16(7), + uint32(8), + uint64(9), + float32(10.123), + float64(11.123), + "binary", + "nchar", + } + assert.Equal(t, [][]driver.Value{expect}, value) +} diff --git a/common/pointer/unsafe_test.go b/common/pointer/unsafe_test.go new file mode 100644 index 0000000..05e0e8b --- /dev/null +++ b/common/pointer/unsafe_test.go @@ -0,0 +1,18 @@ +package pointer + +import ( + "testing" + "unsafe" + + "github.com/stretchr/testify/assert" +) + +func TestAddUintptr(t *testing.T) { + data := []byte{1, 2, 3, 4, 5} + p1 := unsafe.Pointer(&data[0]) + p2 := AddUintptr(p1, 1) + assert.Equal(t, unsafe.Pointer(&data[1]), p2) + v2 := *(*byte)(p2) + assert.Equal(t, byte(2), v2) + +} diff --git a/common/serializer/block.go b/common/serializer/block.go index 03a53f2..50d7a6e 100644 --- a/common/serializer/block.go +++ b/common/serializer/block.go @@ -37,7 +37,7 @@ func BMSetNull(c byte, n int) byte { return c + (1 << (7 - BitPos(n))) } -var ColumnNumerNotMatch = errors.New("number of columns does not match") +var ColumnNumberNotMatch = errors.New("number of columns does not match") var DataTypeWrong = errors.New("wrong data type") func SerializeRawBlock(params []*param.Param, colType *param.ColumnType) ([]byte, error) { @@ -48,7 +48,7 @@ func SerializeRawBlock(params []*param.Param, colType *param.ColumnType) ([]byte return nil, err } if len(colTypes) != columns { - return nil, ColumnNumerNotMatch + return nil, ColumnNumberNotMatch } var block []byte //version int32 diff --git a/common/stmt/field_test.go b/common/stmt/field_test.go new file mode 100644 index 0000000..d50678b --- /dev/null +++ b/common/stmt/field_test.go @@ -0,0 +1,143 @@ +package stmt + +import ( + "testing" + + "github.com/taosdata/driver-go/v3/common" + "github.com/taosdata/driver-go/v3/types" +) + +func TestGetType(t *testing.T) { + tests := []struct { + name string + fieldType int8 + want *types.ColumnType + wantErr bool + }{ + { + name: "Test Bool Type", + fieldType: common.TSDB_DATA_TYPE_BOOL, + want: &types.ColumnType{Type: types.TaosBoolType}, + wantErr: false, + }, + { + name: "Test TinyInt Type", + fieldType: common.TSDB_DATA_TYPE_TINYINT, + want: &types.ColumnType{Type: types.TaosTinyintType}, + wantErr: false, + }, + { + name: "Test SmallInt Type", + fieldType: common.TSDB_DATA_TYPE_SMALLINT, + want: &types.ColumnType{Type: types.TaosSmallintType}, + wantErr: false, + }, + { + name: "Test Int Type", + fieldType: common.TSDB_DATA_TYPE_INT, + want: &types.ColumnType{Type: types.TaosIntType}, + wantErr: false, + }, + { + name: "Test BigInt Type", + fieldType: common.TSDB_DATA_TYPE_BIGINT, + want: &types.ColumnType{Type: types.TaosBigintType}, + wantErr: false, + }, + { + name: "Test UTinyInt Type", + fieldType: common.TSDB_DATA_TYPE_UTINYINT, + want: &types.ColumnType{Type: types.TaosUTinyintType}, + wantErr: false, + }, + { + name: "Test USmallInt Type", + fieldType: common.TSDB_DATA_TYPE_USMALLINT, + want: &types.ColumnType{Type: types.TaosUSmallintType}, + wantErr: false, + }, + { + name: "Test UInt Type", + fieldType: common.TSDB_DATA_TYPE_UINT, + want: &types.ColumnType{Type: types.TaosUIntType}, + wantErr: false, + }, + { + name: "Test UBigInt Type", + fieldType: common.TSDB_DATA_TYPE_UBIGINT, + want: &types.ColumnType{Type: types.TaosUBigintType}, + wantErr: false, + }, + { + name: "Test Float Type", + fieldType: common.TSDB_DATA_TYPE_FLOAT, + want: &types.ColumnType{Type: types.TaosFloatType}, + wantErr: false, + }, + { + name: "Test Double Type", + fieldType: common.TSDB_DATA_TYPE_DOUBLE, + want: &types.ColumnType{Type: types.TaosDoubleType}, + wantErr: false, + }, + { + name: "Test Binary Type", + fieldType: common.TSDB_DATA_TYPE_BINARY, + want: &types.ColumnType{Type: types.TaosBinaryType}, + wantErr: false, + }, + { + name: "Test VarBinary Type", + fieldType: common.TSDB_DATA_TYPE_VARBINARY, + want: &types.ColumnType{Type: types.TaosVarBinaryType}, + wantErr: false, + }, + { + name: "Test Nchar Type", + fieldType: common.TSDB_DATA_TYPE_NCHAR, + want: &types.ColumnType{Type: types.TaosNcharType}, + wantErr: false, + }, + { + name: "Test Timestamp Type", + fieldType: common.TSDB_DATA_TYPE_TIMESTAMP, + want: &types.ColumnType{Type: types.TaosTimestampType}, + wantErr: false, + }, + { + name: "Test Json Type", + fieldType: common.TSDB_DATA_TYPE_JSON, + want: &types.ColumnType{Type: types.TaosJsonType}, + wantErr: false, + }, + { + name: "Test Geometry Type", + fieldType: common.TSDB_DATA_TYPE_GEOMETRY, + want: &types.ColumnType{Type: types.TaosGeometryType}, + wantErr: false, + }, + { + name: "Test Unsupported Type", + fieldType: 0, // An undefined type + want: nil, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &StmtField{ + FieldType: tt.fieldType, + } + + got, err := s.GetType() + if (err != nil) != tt.wantErr { + t.Errorf("GetType() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != nil && tt.want != nil && got.Type != tt.want.Type { + t.Errorf("GetType() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/common/tmq/config_test.go b/common/tmq/config_test.go new file mode 100644 index 0000000..c5a62d5 --- /dev/null +++ b/common/tmq/config_test.go @@ -0,0 +1,52 @@ +package tmq + +import ( + "fmt" + "reflect" + "testing" +) + +func TestConfigMap_Get(t *testing.T) { + t.Parallel() + + config := ConfigMap{ + "key1": "value1", + "key2": 123, + } + + t.Run("Existing Key", func(t *testing.T) { + want := "value1" + if got, err := config.Get("key1", nil); err != nil || got != want { + t.Errorf("Get() = %v, want %v (error: %v)", got, want, err) + } + }) + + t.Run("Type Mismatch", func(t *testing.T) { + wantErr := fmt.Errorf("key2 expects type string, not int") + if got, err := config.Get("key2", "default"); err == nil || got != nil || err.Error() != wantErr.Error() { + t.Errorf("Get() = %v, want error: %v", got, wantErr) + } + }) + + t.Run("Non-Existing Key with Default Value", func(t *testing.T) { + want := "default" + if got, err := config.Get("key3", "default"); err != nil || got != want { + t.Errorf("Get() = %v, want %v (error: %v)", got, want, err) + } + }) +} + +func TestConfigMap_Clone(t *testing.T) { + t.Parallel() + + config := ConfigMap{ + "key1": "value1", + "key2": 123, + } + + clone := config.Clone() + + if !reflect.DeepEqual(config, clone) { + t.Errorf("Clone() = %v, want %v", clone, config) + } +} diff --git a/common/tmq/event_test.go b/common/tmq/event_test.go new file mode 100644 index 0000000..5043a28 --- /dev/null +++ b/common/tmq/event_test.go @@ -0,0 +1,352 @@ +package tmq + +import ( + "database/sql/driver" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + taosError "github.com/taosdata/driver-go/v3/errors" +) + +func TestDataMessage_String(t *testing.T) { + t.Parallel() + + data := []*Data{ + {TableName: "table1", Data: [][]driver.Value{{1, "data1"}}}, + {TableName: "table2", Data: [][]driver.Value{{2, "data2"}}}, + } + message := &DataMessage{ + TopicPartition: TopicPartition{ + Topic: stringPtr("test-topic"), + Partition: 0, + Offset: 100, + }, + dbName: "test-db", + topic: "test-topic", + data: data, + offset: 100, + } + + want := `DataMessage: test-topic[test-db]:[{"TableName":"table1","Data":[[1,"data1"]]},{"TableName":"table2","Data":[[2,"data2"]]}]` + + if got := message.String(); got != want { + t.Errorf("DataMessage.String() = %v, want %v", got, want) + } +} + +func TestMetaMessage_String(t *testing.T) { + t.Parallel() + + meta := &Meta{ + Type: "type", + TableName: "table", + TableType: "tableType", + } + message := &MetaMessage{ + TopicPartition: TopicPartition{ + Topic: stringPtr("test-topic"), + Partition: 0, + Offset: 100, + }, + dbName: "test-db", + topic: "test-topic", + offset: 100, + meta: meta, + } + + want := `MetaMessage: test-topic[test-db]:{"type":"type","tableName":"table","tableType":"tableType","createList":null,"columns":null,"using":"","tagNum":0,"tags":null,"tableNameList":null,"alterType":0,"colName":"","colNewName":"","colType":0,"colLength":0,"colValue":"","colValueNull":false}` + + if got := message.String(); got != want { + t.Errorf("MetaMessage.String() = %v, want %v", got, want) + } +} + +func TestMetaDataMessage_String(t *testing.T) { + t.Parallel() + + meta := &Meta{ + Type: "type", + TableName: "table", + TableType: "tableType", + } + data := []*Data{ + {TableName: "table1", Data: [][]driver.Value{{1, "data1"}}}, + {TableName: "table2", Data: [][]driver.Value{{2, "data2"}}}, + } + metaData := &MetaData{ + Meta: meta, + Data: data, + } + message := &MetaDataMessage{ + TopicPartition: TopicPartition{ + Topic: stringPtr("test-topic"), + Partition: 0, + Offset: 100, + }, + dbName: "test-db", + topic: "test-topic", + offset: 100, + metaData: metaData, + } + + want := `MetaDataMessage: test-topic[test-db]:{"Meta":{"type":"type","tableName":"table","tableType":"tableType","createList":null,"columns":null,"using":"","tagNum":0,"tags":null,"tableNameList":null,"alterType":0,"colName":"","colNewName":"","colType":0,"colLength":0,"colValue":"","colValueNull":false},"Data":[{"TableName":"table1","Data":[[1,"data1"]]},{"TableName":"table2","Data":[[2,"data2"]]}]}` + if got := message.String(); got != want { + t.Errorf("MetaDataMessage.String() = %v, want %v", got, want) + } +} + +func TestNewTMQError(t *testing.T) { + t.Parallel() + + code := 123 + str := "test error" + err := NewTMQError(code, str) + + if err.code != code { + t.Errorf("NewTMQError() code = %v, want %v", err.code, code) + } + + if err.str != str { + t.Errorf("NewTMQError() str = %v, want %v", err.str, str) + } +} + +func TestNewTMQErrorWithErr(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + err error + code int + str string + }{ + { + name: "TaosError", + err: &taosError.TaosError{ + Code: 456, + ErrStr: "taos error", + }, + code: 456, + str: "taos error", + }, + { + name: "OtherError", + err: fmt.Errorf("other error"), + code: ErrorOther, + str: "other error", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := NewTMQErrorWithErr(tc.err) + + if err.code != tc.code { + t.Errorf("NewTMQErrorWithErr() code = %v, want %v", err.code, tc.code) + } + + if err.str != tc.str { + t.Errorf("NewTMQErrorWithErr() str = %v, want %v", err.str, tc.str) + } + }) + } +} + +func TestError_String(t *testing.T) { + t.Parallel() + + code := 789 + str := "test error" + err := Error{code: code, str: str} + want := fmt.Sprintf("[0x%x] %s", code, str) + + if got := err.String(); got != want { + t.Errorf("Error.String() = %v, want %v", got, want) + } +} + +func TestError_Error(t *testing.T) { + t.Parallel() + + code := 789 + str := "test error" + err := Error{code: code, str: str} + want := fmt.Sprintf("[0x%x] %s", code, str) + + if got := err.Error(); got != want { + t.Errorf("Error.Error() = %v, want %v", got, want) + } +} + +func TestError_Code(t *testing.T) { + t.Parallel() + + code := 789 + err := Error{code: code} + + if got := err.Code(); got != code { + t.Errorf("Error.Code() = %v, want %v", got, code) + } +} + +func TestMetaMessage_Offset(t *testing.T) { + t.Parallel() + + message := &MetaMessage{ + offset: 100, + } + + want := Offset(100) + if got := message.Offset(); got != want { + t.Errorf("Offset() = %v, want %v", got, want) + } +} + +func TestMetaMessage_SetDbName(t *testing.T) { + t.Parallel() + + message := &MetaMessage{} + message.SetDbName("test-db") + + want := "test-db" + if got := message.DBName(); got != want { + t.Errorf("DBName() = %v, want %v", got, want) + } +} + +func TestMetaMessage_SetTopic(t *testing.T) { + t.Parallel() + + message := &MetaMessage{} + message.SetTopic("test-topic") + + want := "test-topic" + if got := message.Topic(); got != want { + t.Errorf("Topic() = %v, want %v", got, want) + } +} + +func TestMetaMessage_SetOffset(t *testing.T) { + t.Parallel() + + message := &MetaMessage{} + message.SetOffset(200) + + want := Offset(200) + if got := message.Offset(); got != want { + t.Errorf("Offset() = %v, want %v", got, want) + } +} + +func TestMetaMessage_SetMeta(t *testing.T) { + t.Parallel() + + meta := &Meta{} + message := &MetaMessage{} + message.SetMeta(meta) + + want := meta + if got := message.Value(); got != want { + t.Errorf("Value() = %v, want %v", got, want) + } +} + +func TestDataMessage_SetDbName(t *testing.T) { + t.Parallel() + + message := &DataMessage{} + message.SetDbName("test-db") + + want := "test-db" + if got := message.DBName(); got != want { + t.Errorf("DBName() = %v, want %v", got, want) + } +} + +func TestDataMessage_SetTopic(t *testing.T) { + t.Parallel() + + message := &DataMessage{} + message.SetTopic("test-topic") + + want := "test-topic" + if got := message.Topic(); got != want { + t.Errorf("Topic() = %v, want %v", got, want) + } +} + +func TestDataMessage_SetData(t *testing.T) { + t.Parallel() + + data := []*Data{ + {TableName: "table1", Data: [][]driver.Value{{1, "data1"}}}, + {TableName: "table2", Data: [][]driver.Value{{2, "data2"}}}, + } + message := &DataMessage{} + message.SetData(data) + + want := data + assert.Equal(t, want, message.Value()) +} + +func TestDataMessage_SetOffset(t *testing.T) { + t.Parallel() + + message := &DataMessage{} + message.SetOffset(200) + + want := Offset(200) + if got := message.Offset(); got != want { + t.Errorf("Offset() = %v, want %v", got, want) + } +} + +func TestMetaDataMessage_SetDbName(t *testing.T) { + t.Parallel() + + message := &MetaDataMessage{} + message.SetDbName("test-db") + + want := "test-db" + if got := message.DBName(); got != want { + t.Errorf("DBName() = %v, want %v", got, want) + } +} + +func TestMetaDataMessage_SetTopic(t *testing.T) { + t.Parallel() + + message := &MetaDataMessage{} + message.SetTopic("test-topic") + + want := "test-topic" + if got := message.Topic(); got != want { + t.Errorf("Topic() = %v, want %v", got, want) + } +} + +func TestMetaDataMessage_SetOffset(t *testing.T) { + t.Parallel() + + message := &MetaDataMessage{} + message.SetOffset(200) + + want := Offset(200) + if got := message.Offset(); got != want { + t.Errorf("Offset() = %v, want %v", got, want) + } +} + +func TestMetaDataMessage_SetMetaData(t *testing.T) { + t.Parallel() + + metaData := &MetaData{} + message := &MetaDataMessage{} + message.SetMetaData(metaData) + + want := metaData + if got := message.Value(); got != want { + t.Errorf("Value() = %v, want %v", got, want) + } +} diff --git a/common/tmq/tmq_test.go b/common/tmq/tmq_test.go index 7279c1d..ae9fefb 100644 --- a/common/tmq/tmq_test.go +++ b/common/tmq/tmq_test.go @@ -2,6 +2,8 @@ package tmq import ( "encoding/json" + "errors" + "reflect" "testing" ) @@ -66,3 +68,130 @@ func TestDropJson(t *testing.T) { } t.Log(obj) } + +func TestOffset_String(t *testing.T) { + tests := []struct { + name string + o Offset + want string + }{ + { + name: "Valid Offset", + o: 100, + want: "100", + }, + { + name: "Invalid Offset", + o: OffsetInvalid, + want: "unset", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.o.String(); got != tt.want { + t.Errorf("Offset.String() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestOffset_Valid(t *testing.T) { + tests := []struct { + name string + o Offset + want bool + }{ + { + name: "Valid Offset", + o: 100, + want: true, + }, + { + name: "Invalid Offset", + o: OffsetInvalid, + want: true, + }, + { + name: "Negative Offset", + o: -100, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.o.Valid(); got != tt.want { + t.Errorf("Offset.Valid() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestTopicPartition_String(t *testing.T) { + tests := []struct { + name string + tp TopicPartition + want string + }{ + { + name: "With Error", + tp: TopicPartition{ + Topic: stringPtr("test-topic"), + Partition: 0, + Offset: 100, + Error: errors.New("error message"), + }, + want: "test-topic[0]@100(error message)", + }, + { + name: "Without Error", + tp: TopicPartition{ + Topic: stringPtr("test-topic"), + Partition: 0, + Offset: 100, + }, + want: "test-topic[0]@100", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.tp.String(); got != tt.want { + t.Errorf("TopicPartition.String() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAssignment_MarshalJSON(t *testing.T) { + tests := []struct { + name string + a Assignment + want string + }{ + { + name: "Marshal Assignment", + a: Assignment{ + VGroupID: 1, + Offset: 100, + Begin: 50, + End: 150, + }, + want: `{"vgroup_id":1,"offset":100,"begin":50,"end":150}`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := json.Marshal(tt.a) + if err != nil { + t.Errorf("MarshalJSON error: %v", err) + return + } + if !reflect.DeepEqual(string(got), tt.want) { + t.Errorf("MarshalJSON = %v, want %v", string(got), tt.want) + } + }) + } +} + +func stringPtr(s string) *string { + return &s +} diff --git a/errors/errors_test.go b/errors/errors_test.go index e29f9ae..0064c90 100644 --- a/errors/errors_test.go +++ b/errors/errors_test.go @@ -1,6 +1,10 @@ package errors -import "testing" +import ( + "testing" + + "github.com/stretchr/testify/assert" +) // @author: xftan // @date: 2023/10/13 11:20 @@ -32,3 +36,13 @@ func TestNewError(t *testing.T) { }) } } + +func TestError(t *testing.T) { + invalidError := ErrTscInvalidConnection.Error() + assert.Equal(t, "[0x20b] Invalid connection", invalidError) + unknownError := &TaosError{ + Code: 0xffff, + ErrStr: "unknown error", + } + assert.Equal(t, "unknown error", unknownError.Error()) +} diff --git a/examples/stmtoverws/main.go b/examples/stmtoverws/main.go index a9c0d5a..6e083df 100644 --- a/examples/stmtoverws/main.go +++ b/examples/stmtoverws/main.go @@ -19,7 +19,7 @@ func main() { defer db.Close() prepareEnv(db) - config := stmt.NewConfig("ws://127.0.0.1:6041/rest/stmt", 0) + config := stmt.NewConfig("ws://127.0.0.1:6041", 0) config.SetConnectUser("root") config.SetConnectPass("taosdata") config.SetConnectDB("example_ws_stmt") diff --git a/examples/tmq/main.go b/examples/tmq/main.go index eb110ce..01e2010 100644 --- a/examples/tmq/main.go +++ b/examples/tmq/main.go @@ -27,16 +27,15 @@ func main() { panic(err) } consumer, err := tmq.NewConsumer(&tmqcommon.ConfigMap{ - "group.id": "test", - "auto.offset.reset": "earliest", - "td.connect.ip": "127.0.0.1", - "td.connect.user": "root", - "td.connect.pass": "taosdata", - "td.connect.port": "6030", - "client.id": "test_tmq_client", - "enable.auto.commit": "false", - "experimental.snapshot.enable": "true", - "msg.with.table.name": "true", + "group.id": "test", + "auto.offset.reset": "earliest", + "td.connect.ip": "127.0.0.1", + "td.connect.user": "root", + "td.connect.pass": "taosdata", + "td.connect.port": "6030", + "client.id": "test_tmq_client", + "enable.auto.commit": "false", + "msg.with.table.name": "true", }) if err != nil { panic(err) diff --git a/examples/tmqoverws/main.go b/examples/tmqoverws/main.go index b0fdf91..cac691a 100644 --- a/examples/tmqoverws/main.go +++ b/examples/tmqoverws/main.go @@ -18,7 +18,7 @@ func main() { defer db.Close() prepareEnv(db) consumer, err := tmq.NewConsumer(&tmqcommon.ConfigMap{ - "ws.url": "ws://127.0.0.1:6041/rest/tmq", + "ws.url": "ws://127.0.0.1:6041", "ws.message.channelLen": uint(0), "ws.message.timeout": common.DefaultMessageTimeout, "ws.message.writeWait": common.DefaultWriteWait, diff --git a/taosRestful/connection.go b/taosRestful/connection.go index b139991..4c27e04 100644 --- a/taosRestful/connection.go +++ b/taosRestful/connection.go @@ -3,6 +3,7 @@ package taosRestful import ( "compress/gzip" "context" + "crypto/tls" "database/sql/driver" "encoding/base64" "errors" @@ -48,18 +49,24 @@ func newTaosConn(cfg *config) (*taosConn, error) { readBufferSize = 4 << 10 } tc := &taosConn{cfg: cfg, readBufferSize: readBufferSize} + transport := &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + DisableCompression: cfg.disableCompression, + } + if cfg.skipVerify { + transport.TLSClientConfig = &tls.Config{ + InsecureSkipVerify: true, + } + } tc.client = &http.Client{ - Transport: &http.Transport{ - Proxy: http.ProxyFromEnvironment, - DialContext: (&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - }).DialContext, - IdleConnTimeout: 90 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - DisableCompression: cfg.disableCompression, - }, + Transport: transport, } path := "/rest/sql" if len(cfg.dbName) != 0 { @@ -230,6 +237,7 @@ func (tc *taosConn) taosQuery(ctx context.Context, sql string, bufferSize int) ( return nil, fmt.Errorf("server response: %s - %s", resp.Status, string(body)) } respBody := resp.Body + defer ioutil.ReadAll(respBody) if !tc.cfg.disableCompression && EqualFold(resp.Header.Get("Content-Encoding"), "gzip") { respBody, err = gzip.NewReader(resp.Body) if err != nil { @@ -339,6 +347,16 @@ func marshalBody(body io.Reader, bufferSize int) (*common.TDEngineRestfulResp, e row[column] = iter.ReadUint32() case common.TSDB_DATA_TYPE_UBIGINT: row[column] = iter.ReadUint64() + case common.TSDB_DATA_TYPE_VARBINARY, common.TSDB_DATA_TYPE_GEOMETRY: + data := iter.ReadStringAsSlice() + if len(data)%2 != 0 { + iter.ReportError("read varbinary", fmt.Sprintf("invalid length %s", string(data))) + } + value := make([]byte, len(data)/2) + for i := 0; i < len(data); i += 2 { + value[i/2] = hexCharToDigit(data[i])<<4 | hexCharToDigit(data[i+1]) + } + row[column] = value default: row[column] = nil iter.Skip() @@ -385,3 +403,14 @@ func lower(b byte) byte { } return b } + +func hexCharToDigit(char byte) uint8 { + switch { + case char >= '0' && char <= '9': + return char - '0' + case char >= 'a' && char <= 'f': + return char - 'a' + 10 + default: + panic("assertion failed: invalid hex char") + } +} diff --git a/taosRestful/connector_test.go b/taosRestful/connector_test.go index d471b22..eac38db 100644 --- a/taosRestful/connector_test.go +++ b/taosRestful/connector_test.go @@ -1,9 +1,24 @@ package taosRestful import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + crand "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" "database/sql" + "encoding/pem" "fmt" + "log" + "math/big" "math/rand" + "net/http" + "net/http/httputil" + "net/url" + "reflect" + "strings" "testing" "time" @@ -11,11 +26,78 @@ import ( "github.com/taosdata/driver-go/v3/types" ) +func generateCreateTableSql(db string, withJson bool) string { + createSql := fmt.Sprintf("create table if not exists %s.alltype(ts timestamp,"+ + "c1 bool,"+ + "c2 tinyint,"+ + "c3 smallint,"+ + "c4 int,"+ + "c5 bigint,"+ + "c6 tinyint unsigned,"+ + "c7 smallint unsigned,"+ + "c8 int unsigned,"+ + "c9 bigint unsigned,"+ + "c10 float,"+ + "c11 double,"+ + "c12 binary(20),"+ + "c13 nchar(20),"+ + "c14 varbinary(100),"+ + "c15 geometry(100)"+ + ")", + db) + if withJson { + createSql += " tags(t json)" + } + return createSql +} + +func generateValues() (value []interface{}, scanValue []interface{}, insertSql string) { + rand.Seed(time.Now().UnixNano()) + v1 := true + v2 := int8(rand.Int()) + v3 := int16(rand.Int()) + v4 := rand.Int31() + v5 := int64(rand.Int31()) + v6 := uint8(rand.Uint32()) + v7 := uint16(rand.Uint32()) + v8 := rand.Uint32() + v9 := uint64(rand.Uint32()) + v10 := rand.Float32() + v11 := rand.Float64() + v12 := "test_binary" + v13 := "test_nchar" + v14 := []byte("test_varbinary") + v15 := []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40} + ts := time.Now().Round(time.Millisecond).UTC() + var ( + cts time.Time + c1 bool + c2 int8 + c3 int16 + c4 int32 + c5 int64 + c6 uint8 + c7 uint16 + c8 uint32 + c9 uint64 + c10 float32 + c11 float64 + c12 string + c13 string + c14 []byte + c15 []byte + ) + return []interface{}{ + ts, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, + }, []interface{}{cts, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15}, + fmt.Sprintf(`values('%s',%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,'test_binary','test_nchar','test_varbinary','point(100 100)')`, ts.Format(time.RFC3339Nano), v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11) +} + // @author: xftan // @date: 2021/12/21 10:59 // @description: test restful query of all type func TestAllTypeQuery(t *testing.T) { - rand.Seed(time.Now().UnixNano()) + database := "restful_test" db, err := sql.Open("taosRestful", dataSourceName) if err != nil { t.Fatal(err) @@ -26,57 +108,25 @@ func TestAllTypeQuery(t *testing.T) { t.Fatal(err) } defer func() { - _, err = db.Exec("drop database if exists restful_test") + _, err = db.Exec(fmt.Sprintf("drop database if exists %s", database)) if err != nil { t.Fatal(err) } }() - _, err = db.Exec("create database if not exists restful_test") + _, err = db.Exec(fmt.Sprintf("create database if not exists %s", database)) if err != nil { t.Fatal(err) } - var ( - v1 = true - v2 = int8(rand.Int()) - v3 = int16(rand.Int()) - v4 = rand.Int31() - v5 = int64(rand.Int31()) - v6 = uint8(rand.Uint32()) - v7 = uint16(rand.Uint32()) - v8 = rand.Uint32() - v9 = uint64(rand.Uint32()) - v10 = rand.Float32() - v11 = rand.Float64() - v12 = "test_binary" - v13 = "test_nchar" - ) - - _, err = db.Exec("create table if not exists restful_test.alltype(ts timestamp," + - "c1 bool," + - "c2 tinyint," + - "c3 smallint," + - "c4 int," + - "c5 bigint," + - "c6 tinyint unsigned," + - "c7 smallint unsigned," + - "c8 int unsigned," + - "c9 bigint unsigned," + - "c10 float," + - "c11 double," + - "c12 binary(20)," + - "c13 nchar(20)" + - ")" + - "tags(t json)", - ) + _, err = db.Exec(generateCreateTableSql(database, true)) if err != nil { t.Fatal(err) } - now := time.Now().Round(time.Millisecond) - _, err = db.Exec(fmt.Sprintf(`insert into restful_test.t1 using restful_test.alltype tags('{"a":"b"}') values('%s',%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,'test_binary','test_nchar')`, now.Format(time.RFC3339Nano), v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11)) + colValues, scanValues, insertSql := generateValues() + _, err = db.Exec(fmt.Sprintf(`insert into %s.t1 using %s.alltype tags('{"a":"b"}') %s`, database, database, insertSql)) if err != nil { t.Fatal(err) } - rows, err := db.Query(fmt.Sprintf("select * from restful_test.alltype where ts = '%s'", now.Format(time.RFC3339Nano))) + rows, err := db.Query(fmt.Sprintf("select * from %s.alltype where ts = '%s'", database, colValues[0].(time.Time).Format(time.RFC3339Nano))) assert.NoError(t, err) columns, err := rows.Columns() assert.NoError(t, err) @@ -84,71 +134,27 @@ func TestAllTypeQuery(t *testing.T) { cTypes, err := rows.ColumnTypes() assert.NoError(t, err) t.Log(cTypes) + var tt types.RawMessage + dest := make([]interface{}, len(scanValues)+1) + for i := range scanValues { + dest[i] = reflect.ValueOf(&scanValues[i]).Interface() + } + dest[len(scanValues)] = &tt for rows.Next() { - var ( - ts time.Time - c1 bool - c2 int8 - c3 int16 - c4 int32 - c5 int64 - c6 uint8 - c7 uint16 - c8 uint32 - c9 uint64 - c10 float32 - c11 float64 - c12 string - c13 string - tt types.RawMessage - ) - err := rows.Scan( - &ts, - &c1, - &c2, - &c3, - &c4, - &c5, - &c6, - &c7, - &c8, - &c9, - &c10, - &c11, - &c12, - &c13, - &tt, - ) - assert.Equal(t, now.UTC(), ts.UTC()) - assert.Equal(t, v1, c1) - assert.Equal(t, v2, c2) - assert.Equal(t, v3, c3) - assert.Equal(t, v4, c4) - assert.Equal(t, v5, c5) - assert.Equal(t, v6, c6) - assert.Equal(t, v7, c7) - assert.Equal(t, v8, c8) - assert.Equal(t, v9, c9) - assert.Equal(t, v10, c10) - assert.Equal(t, v11, c11) - assert.Equal(t, v12, c12) - assert.Equal(t, v13, c13) - assert.Equal(t, types.RawMessage(`{"a":"b"}`), tt) - if err != nil { - t.Fatal(err) - } - if ts.IsZero() { - t.Fatal(ts) - } - + err := rows.Scan(dest...) + assert.NoError(t, err) } + for i, v := range colValues { + assert.Equal(t, v, scanValues[i]) + } + assert.Equal(t, types.RawMessage(`{"a":"b"}`), tt) } // @author: xftan // @date: 2022/2/8 12:51 // @description: test query all null value func TestAllTypeQueryNull(t *testing.T) { - rand.Seed(time.Now().UnixNano()) + database := "restful_test_null" db, err := sql.Open("taosRestful", dataSourceName) if err != nil { t.Fatal(err) @@ -159,42 +165,29 @@ func TestAllTypeQueryNull(t *testing.T) { t.Fatal(err) } defer func() { - _, err = db.Exec("drop database if exists restful_test_null") + _, err = db.Exec(fmt.Sprintf("drop database if exists %s", database)) if err != nil { t.Fatal(err) } }() - _, err = db.Exec("create database if not exists restful_test_null") + _, err = db.Exec(fmt.Sprintf("create database if not exists %s", database)) if err != nil { t.Fatal(err) } - - _, err = db.Exec("create table if not exists restful_test_null.alltype(ts timestamp," + - "c1 bool," + - "c2 tinyint," + - "c3 smallint," + - "c4 int," + - "c5 bigint," + - "c6 tinyint unsigned," + - "c7 smallint unsigned," + - "c8 int unsigned," + - "c9 bigint unsigned," + - "c10 float," + - "c11 double," + - "c12 binary(20)," + - "c13 nchar(20)" + - ")" + - "tags(t json)", - ) + _, err = db.Exec(generateCreateTableSql(database, true)) if err != nil { t.Fatal(err) } - now := time.Now().Round(time.Millisecond) - _, err = db.Exec(fmt.Sprintf(`insert into restful_test_null.t1 using restful_test_null.alltype tags('null') values('%s',null,null,null,null,null,null,null,null,null,null,null,null,null)`, now.Format(time.RFC3339Nano))) + colValues, _, _ := generateValues() + builder := &strings.Builder{} + for i := 1; i < len(colValues); i++ { + builder.WriteString(",null") + } + _, err = db.Exec(fmt.Sprintf(`insert into %s.t1 using %s.alltype tags('{"a":"b"}') values('%s'%s)`, database, database, colValues[0].(time.Time).Format(time.RFC3339Nano), builder.String())) if err != nil { t.Fatal(err) } - rows, err := db.Query(fmt.Sprintf("select * from restful_test_null.alltype where ts = '%s'", now.Format(time.RFC3339Nano))) + rows, err := db.Query(fmt.Sprintf("select * from %s.alltype where ts = '%s'", database, colValues[0].(time.Time).Format(time.RFC3339Nano))) assert.NoError(t, err) columns, err := rows.Columns() assert.NoError(t, err) @@ -202,72 +195,32 @@ func TestAllTypeQueryNull(t *testing.T) { cTypes, err := rows.ColumnTypes() assert.NoError(t, err) t.Log(cTypes) + values := make([]interface{}, len(cTypes)) + values[0] = new(time.Time) + for i := 1; i < len(colValues); i++ { + var v interface{} + values[i] = &v + } + var tt types.RawMessage + values[len(colValues)] = &tt for rows.Next() { - var ( - ts time.Time - c1 *bool - c2 *int8 - c3 *int16 - c4 *int32 - c5 *int64 - c6 *uint8 - c7 *uint16 - c8 *uint32 - c9 *uint64 - c10 *float32 - c11 *float64 - c12 *string - c13 *string - tt *types.RawMessage - ) - err := rows.Scan( - &ts, - &c1, - &c2, - &c3, - &c4, - &c5, - &c6, - &c7, - &c8, - &c9, - &c10, - &c11, - &c12, - &c13, - &tt, - ) - assert.Equal(t, now.UTC(), ts.UTC()) - assert.Nil(t, c1) - assert.Nil(t, c2) - assert.Nil(t, c3) - assert.Nil(t, c4) - assert.Nil(t, c5) - assert.Nil(t, c6) - assert.Nil(t, c7) - assert.Nil(t, c8) - assert.Nil(t, c9) - assert.Nil(t, c10) - assert.Nil(t, c11) - assert.Nil(t, c12) - assert.Nil(t, c13) - assert.Equal(t, types.RawMessage("null"), *tt) + err := rows.Scan(values...) if err != nil { - t.Fatal(err) } - if ts.IsZero() { - t.Fatal(ts) - } - } + assert.Equal(t, *values[0].(*time.Time), colValues[0].(time.Time)) + for i := 1; i < len(values)-1; i++ { + assert.Nil(t, *values[i].(*interface{})) + } + assert.Equal(t, types.RawMessage(`{"a":"b"}`), *(values[len(values)-1]).(*types.RawMessage)) } // @author: xftan // @date: 2022/2/10 14:32 // @description: test restful query of all type with compression func TestAllTypeQueryCompression(t *testing.T) { - rand.Seed(time.Now().UnixNano()) + database := "restful_test_compression" db, err := sql.Open("taosRestful", dataSourceNameWithCompression) if err != nil { t.Fatal(err) @@ -278,57 +231,25 @@ func TestAllTypeQueryCompression(t *testing.T) { t.Fatal(err) } defer func() { - _, err = db.Exec("drop database if exists restful_test") + _, err = db.Exec(fmt.Sprintf("drop database if exists %s", database)) if err != nil { t.Fatal(err) } }() - _, err = db.Exec("create database if not exists restful_test") + _, err = db.Exec(fmt.Sprintf("create database if not exists %s", database)) if err != nil { t.Fatal(err) } - var ( - v1 = true - v2 = int8(rand.Int()) - v3 = int16(rand.Int()) - v4 = rand.Int31() - v5 = int64(rand.Int31()) - v6 = uint8(rand.Uint32()) - v7 = uint16(rand.Uint32()) - v8 = rand.Uint32() - v9 = uint64(rand.Uint32()) - v10 = rand.Float32() - v11 = rand.Float64() - v12 = "test_binary" - v13 = "test_nchar" - ) - - _, err = db.Exec("create table if not exists restful_test.alltype(ts timestamp," + - "c1 bool," + - "c2 tinyint," + - "c3 smallint," + - "c4 int," + - "c5 bigint," + - "c6 tinyint unsigned," + - "c7 smallint unsigned," + - "c8 int unsigned," + - "c9 bigint unsigned," + - "c10 float," + - "c11 double," + - "c12 binary(20)," + - "c13 nchar(20)" + - ")" + - "tags(t json)", - ) + _, err = db.Exec(generateCreateTableSql(database, true)) if err != nil { t.Fatal(err) } - now := time.Now().Round(time.Millisecond) - _, err = db.Exec(fmt.Sprintf(`insert into restful_test.t1 using restful_test.alltype tags('{"a":"b"}') values('%s',%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,'test_binary','test_nchar')`, now.Format(time.RFC3339Nano), v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11)) + colValues, scanValues, insertSql := generateValues() + _, err = db.Exec(fmt.Sprintf(`insert into %s.t1 using %s.alltype tags('{"a":"b"}') %s`, database, database, insertSql)) if err != nil { t.Fatal(err) } - rows, err := db.Query(fmt.Sprintf("select * from restful_test.alltype where ts = '%s'", now.Format(time.RFC3339Nano))) + rows, err := db.Query(fmt.Sprintf("select * from %s.alltype where ts = '%s'", database, colValues[0].(time.Time).Format(time.RFC3339Nano))) assert.NoError(t, err) columns, err := rows.Columns() assert.NoError(t, err) @@ -336,70 +257,27 @@ func TestAllTypeQueryCompression(t *testing.T) { cTypes, err := rows.ColumnTypes() assert.NoError(t, err) t.Log(cTypes) + var tt types.RawMessage + dest := make([]interface{}, len(scanValues)+1) + for i := range scanValues { + dest[i] = reflect.ValueOf(&scanValues[i]).Interface() + } + dest[len(scanValues)] = &tt for rows.Next() { - var ( - ts time.Time - c1 bool - c2 int8 - c3 int16 - c4 int32 - c5 int64 - c6 uint8 - c7 uint16 - c8 uint32 - c9 uint64 - c10 float32 - c11 float64 - c12 string - c13 string - tt types.RawMessage - ) - err := rows.Scan( - &ts, - &c1, - &c2, - &c3, - &c4, - &c5, - &c6, - &c7, - &c8, - &c9, - &c10, - &c11, - &c12, - &c13, - &tt, - ) - assert.Equal(t, now.UTC(), ts.UTC()) - assert.Equal(t, v1, c1) - assert.Equal(t, v2, c2) - assert.Equal(t, v3, c3) - assert.Equal(t, v4, c4) - assert.Equal(t, v5, c5) - assert.Equal(t, v6, c6) - assert.Equal(t, v7, c7) - assert.Equal(t, v8, c8) - assert.Equal(t, v9, c9) - assert.Equal(t, v10, c10) - assert.Equal(t, v11, c11) - assert.Equal(t, v12, c12) - assert.Equal(t, v13, c13) - assert.Equal(t, types.RawMessage(`{"a":"b"}`), tt) - if err != nil { - t.Fatal(err) - } - if ts.IsZero() { - t.Fatal(ts) - } + err := rows.Scan(dest...) + assert.NoError(t, err) } + for i, v := range colValues { + assert.Equal(t, v, scanValues[i]) + } + assert.Equal(t, types.RawMessage(`{"a":"b"}`), tt) } // @author: xftan // @date: 2022/5/19 15:22 // @description: test restful query of all type without json (httpd) func TestAllTypeQueryWithoutJson(t *testing.T) { - rand.Seed(time.Now().UnixNano()) + database := "restful_test_without_json" db, err := sql.Open("taosRestful", dataSourceName) if err != nil { t.Fatal(err) @@ -410,56 +288,25 @@ func TestAllTypeQueryWithoutJson(t *testing.T) { t.Fatal(err) } defer func() { - _, err = db.Exec("drop database if exists restful_test_without_json") + _, err = db.Exec(fmt.Sprintf("drop database if exists %s", database)) if err != nil { t.Fatal(err) } }() - _, err = db.Exec("create database if not exists restful_test_without_json") + _, err = db.Exec(fmt.Sprintf("create database if not exists %s", database)) if err != nil { t.Fatal(err) } - var ( - v1 = false - v2 = int8(rand.Int()) - v3 = int16(rand.Int()) - v4 = rand.Int31() - v5 = int64(rand.Int31()) - v6 = uint8(rand.Uint32()) - v7 = uint16(rand.Uint32()) - v8 = rand.Uint32() - v9 = uint64(rand.Uint32()) - v10 = rand.Float32() - v11 = rand.Float64() - v12 = "test_binary" - v13 = "test_nchar" - ) - - _, err = db.Exec("create table if not exists restful_test_without_json.all_type(ts timestamp," + - "c1 bool," + - "c2 tinyint," + - "c3 smallint," + - "c4 int," + - "c5 bigint," + - "c6 tinyint unsigned," + - "c7 smallint unsigned," + - "c8 int unsigned," + - "c9 bigint unsigned," + - "c10 float," + - "c11 double," + - "c12 binary(20)," + - "c13 nchar(20)" + - ")", - ) + _, err = db.Exec(generateCreateTableSql(database, false)) if err != nil { t.Fatal(err) } - now := time.Now().Round(time.Millisecond) - _, err = db.Exec(fmt.Sprintf(`insert into restful_test_without_json.all_type values('%s',%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,'test_binary','test_nchar')`, now.Format(time.RFC3339Nano), v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11)) + colValues, scanValues, insertSql := generateValues() + _, err = db.Exec(fmt.Sprintf(`insert into %s.alltype %s`, database, insertSql)) if err != nil { t.Fatal(err) } - rows, err := db.Query(fmt.Sprintf("select * from restful_test_without_json.all_type where ts = '%s'", now.Format(time.RFC3339Nano))) + rows, err := db.Query(fmt.Sprintf("select * from %s.alltype where ts = '%s'", database, colValues[0].(time.Time).Format(time.RFC3339Nano))) assert.NoError(t, err) columns, err := rows.Columns() assert.NoError(t, err) @@ -467,60 +314,16 @@ func TestAllTypeQueryWithoutJson(t *testing.T) { cTypes, err := rows.ColumnTypes() assert.NoError(t, err) t.Log(cTypes) + dest := make([]interface{}, len(scanValues)) + for i := range scanValues { + dest[i] = reflect.ValueOf(&scanValues[i]).Interface() + } for rows.Next() { - var ( - ts time.Time - c1 bool - c2 int8 - c3 int16 - c4 int32 - c5 int64 - c6 uint8 - c7 uint16 - c8 uint32 - c9 uint64 - c10 float32 - c11 float64 - c12 string - c13 string - ) - err := rows.Scan( - &ts, - &c1, - &c2, - &c3, - &c4, - &c5, - &c6, - &c7, - &c8, - &c9, - &c10, - &c11, - &c12, - &c13, - ) - assert.Equal(t, now.UTC(), ts.UTC()) - assert.Equal(t, v1, c1) - assert.Equal(t, v2, c2) - assert.Equal(t, v3, c3) - assert.Equal(t, v4, c4) - assert.Equal(t, v5, c5) - assert.Equal(t, v6, c6) - assert.Equal(t, v7, c7) - assert.Equal(t, v8, c8) - assert.Equal(t, v9, c9) - assert.Equal(t, v10, c10) - assert.Equal(t, v11, c11) - assert.Equal(t, v12, c12) - assert.Equal(t, v13, c13) - if err != nil { - t.Fatal(err) - } - if ts.IsZero() { - t.Fatal(ts) - } - + err := rows.Scan(dest...) + assert.NoError(t, err) + } + for i, v := range colValues { + assert.Equal(t, v, scanValues[i]) } } @@ -528,7 +331,7 @@ func TestAllTypeQueryWithoutJson(t *testing.T) { // @date: 2022/5/19 15:22 // @description: test query all null value without json (httpd) func TestAllTypeQueryNullWithoutJson(t *testing.T) { - rand.Seed(time.Now().UnixNano()) + database := "restful_test_without_json_null" db, err := sql.Open("taosRestful", dataSourceName) if err != nil { t.Fatal(err) @@ -539,41 +342,30 @@ func TestAllTypeQueryNullWithoutJson(t *testing.T) { t.Fatal(err) } defer func() { - _, err = db.Exec("drop database if exists restful_test_without_json_null") + _, err = db.Exec(fmt.Sprintf("drop database if exists %s", database)) if err != nil { t.Fatal(err) } }() - _, err = db.Exec("create database if not exists restful_test_without_json_null") + _, err = db.Exec(fmt.Sprintf("create database if not exists %s", database)) if err != nil { t.Fatal(err) } - - _, err = db.Exec("create table if not exists restful_test_without_json_null.all_type(ts timestamp," + - "c1 bool," + - "c2 tinyint," + - "c3 smallint," + - "c4 int," + - "c5 bigint," + - "c6 tinyint unsigned," + - "c7 smallint unsigned," + - "c8 int unsigned," + - "c9 bigint unsigned," + - "c10 float," + - "c11 double," + - "c12 binary(20)," + - "c13 nchar(20)" + - ")", - ) + _, err = db.Exec(generateCreateTableSql(database, false)) if err != nil { t.Fatal(err) } - now := time.Now().Round(time.Millisecond) - _, err = db.Exec(fmt.Sprintf(`insert into restful_test_without_json_null.all_type values('%s',null,null,null,null,null,null,null,null,null,null,null,null,null)`, now.Format(time.RFC3339Nano))) + colValues, _, _ := generateValues() + builder := &strings.Builder{} + for i := 1; i < len(colValues); i++ { + builder.WriteString(",null") + } + insertSql := fmt.Sprintf(`insert into %s.alltype values('%s'%s)`, database, colValues[0].(time.Time).Format(time.RFC3339Nano), builder.String()) + _, err = db.Exec(insertSql) if err != nil { t.Fatal(err) } - rows, err := db.Query(fmt.Sprintf("select * from restful_test_without_json_null.all_type where ts = '%s'", now.Format(time.RFC3339Nano))) + rows, err := db.Query(fmt.Sprintf("select * from %s.alltype where ts = '%s'", database, colValues[0].(time.Time).Format(time.RFC3339Nano))) assert.NoError(t, err) columns, err := rows.Columns() assert.NoError(t, err) @@ -581,60 +373,160 @@ func TestAllTypeQueryNullWithoutJson(t *testing.T) { cTypes, err := rows.ColumnTypes() assert.NoError(t, err) t.Log(cTypes) + values := make([]interface{}, len(cTypes)) + values[0] = new(time.Time) + for i := 1; i < len(colValues); i++ { + var v interface{} + values[i] = &v + } for rows.Next() { - var ( - ts time.Time - c1 *bool - c2 *int8 - c3 *int16 - c4 *int32 - c5 *int64 - c6 *uint8 - c7 *uint16 - c8 *uint32 - c9 *uint64 - c10 *float32 - c11 *float64 - c12 *string - c13 *string - ) - err := rows.Scan( - &ts, - &c1, - &c2, - &c3, - &c4, - &c5, - &c6, - &c7, - &c8, - &c9, - &c10, - &c11, - &c12, - &c13, - ) - assert.Equal(t, now.UTC(), ts.UTC()) - assert.Nil(t, c1) - assert.Nil(t, c2) - assert.Nil(t, c3) - assert.Nil(t, c4) - assert.Nil(t, c5) - assert.Nil(t, c6) - assert.Nil(t, c7) - assert.Nil(t, c8) - assert.Nil(t, c9) - assert.Nil(t, c10) - assert.Nil(t, c11) - assert.Nil(t, c12) - assert.Nil(t, c13) + err := rows.Scan(values...) if err != nil { - t.Fatal(err) } - if ts.IsZero() { - t.Fatal(ts) + } + assert.Equal(t, *values[0].(*time.Time), colValues[0].(time.Time)) + for i := 1; i < len(values)-1; i++ { + assert.Nil(t, *values[i].(*interface{})) + } +} + +func generateSelfSignedCert() (tls.Certificate, error) { + priv, err := ecdsa.GenerateKey(elliptic.P384(), crand.Reader) + if err != nil { + return tls.Certificate{}, err + } + + notBefore := time.Now() + notAfter := notBefore.Add(365 * 24 * time.Hour) + + serialNumber, err := crand.Int(crand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + if err != nil { + return tls.Certificate{}, err + } + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"Your Company"}, + }, + NotBefore: notBefore, + NotAfter: notAfter, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + certDER, err := x509.CreateCertificate(crand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + return tls.Certificate{}, err + } + + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + keyPEM, err := x509.MarshalECPrivateKey(priv) + if err != nil { + return tls.Certificate{}, err + } + + keyPEMBlock := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyPEM}) + + return tls.X509KeyPair(certPEM, keyPEMBlock) +} + +func startProxy() *http.Server { + // Generate self-signed certificate + cert, err := generateSelfSignedCert() + if err != nil { + log.Fatalf("Failed to generate self-signed certificate: %v", err) + } + + target := "http://127.0.0.1:6041" + proxyURL, err := url.Parse(target) + if err != nil { + log.Fatalf("Failed to parse target URL: %v", err) + } + + proxy := httputil.NewSingleHostReverseProxy(proxyURL) + proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, e error) { + http.Error(w, "Proxy error", http.StatusBadGateway) + } + mux := http.NewServeMux() + mux.Handle("/", proxy) + + server := &http.Server{ + Addr: ":34443", + Handler: mux, + TLSConfig: &tls.Config{Certificates: []tls.Certificate{cert}}, + // Setup server timeouts for better handling of idle connections and slowloris attacks + WriteTimeout: 10 * time.Second, + ReadTimeout: 10 * time.Second, + IdleTimeout: 30 * time.Second, + } + + log.Println("Starting server on :34443") + go func() { + err = server.ListenAndServeTLS("", "") + if err != nil && err != http.ErrServerClosed { + log.Fatalf("Failed to start HTTPS server: %v", err) } + }() + return server +} +func TestSSL(t *testing.T) { + dataSourceNameWithSkipVerify := fmt.Sprintf("%s:%s@https(%s:%d)/?skipVerify=true", user, password, host, 34443) + server := startProxy() + defer server.Shutdown(context.Background()) + time.Sleep(1 * time.Second) + database := "restful_test_ssl" + db, err := sql.Open("taosRestful", dataSourceNameWithSkipVerify) + if err != nil { + t.Fatal(err) + } + defer db.Close() + err = db.Ping() + if err != nil { + t.Fatal(err) + } + defer func() { + _, err = db.Exec(fmt.Sprintf("drop database if exists %s", database)) + if err != nil { + t.Fatal(err) + } + }() + _, err = db.Exec(fmt.Sprintf("create database if not exists %s", database)) + if err != nil { + t.Fatal(err) + } + _, err = db.Exec(generateCreateTableSql(database, true)) + if err != nil { + t.Fatal(err) + } + colValues, scanValues, insertSql := generateValues() + _, err = db.Exec(fmt.Sprintf(`insert into %s.t1 using %s.alltype tags('{"a":"b"}') %s`, database, database, insertSql)) + if err != nil { + t.Fatal(err) + } + rows, err := db.Query(fmt.Sprintf("select * from %s.alltype where ts = '%s'", database, colValues[0].(time.Time).Format(time.RFC3339Nano))) + assert.NoError(t, err) + columns, err := rows.Columns() + assert.NoError(t, err) + t.Log(columns) + cTypes, err := rows.ColumnTypes() + assert.NoError(t, err) + t.Log(cTypes) + var tt types.RawMessage + dest := make([]interface{}, len(scanValues)+1) + for i := range scanValues { + dest[i] = reflect.ValueOf(&scanValues[i]).Interface() + } + dest[len(scanValues)] = &tt + for rows.Next() { + err := rows.Scan(dest...) + assert.NoError(t, err) + } + for i, v := range colValues { + assert.Equal(t, v, scanValues[i]) } + assert.Equal(t, types.RawMessage(`{"a":"b"}`), tt) } diff --git a/taosRestful/dsn.go b/taosRestful/dsn.go index 3936231..612962a 100644 --- a/taosRestful/dsn.go +++ b/taosRestful/dsn.go @@ -30,6 +30,7 @@ type config struct { disableCompression bool readBufferSize int token string // cloud platform token + skipVerify bool } // NewConfig creates a new Config and sets default values. @@ -154,6 +155,11 @@ func parseDSNParams(cfg *config, params string) (err error) { } case "token": cfg.token = value + case "skipVerify": + cfg.skipVerify, err = strconv.ParseBool(value) + if err != nil { + return &errors.TaosError{Code: 0xffff, ErrStr: "invalid bool value: " + value} + } default: // lazy init if cfg.params == nil { diff --git a/taosRestful/dsn_test.go b/taosRestful/dsn_test.go index e71b813..2ec3fbd 100644 --- a/taosRestful/dsn_test.go +++ b/taosRestful/dsn_test.go @@ -10,15 +10,16 @@ import ( // @description: test parse dsn func TestParseDsn(t *testing.T) { tcs := []struct { - dsn string - errs string - user string - passwd string - net string - addr string - port int - dbName string - token string + dsn string + errs string + user string + passwd string + net string + addr string + port int + dbName string + token string + skipVerify bool }{{}, {dsn: "abcd", errs: "invalid DSN: missing the slash separating the database name"}, {dsn: "user:passwd@http(fqdn:6041)/dbname", user: "user", passwd: "passwd", net: "http", addr: "fqdn", port: 6041, dbName: "dbname"}, @@ -28,6 +29,7 @@ func TestParseDsn(t *testing.T) { {dsn: "user:passwd@https(:0)/", user: "user", passwd: "passwd", net: "https"}, {dsn: "user:passwd@https(:0)/?interpolateParams=false&test=1", user: "user", passwd: "passwd", net: "https"}, {dsn: "user:passwd@https(:0)/?interpolateParams=false&token=token", user: "user", passwd: "passwd", net: "https", token: "token"}, + {dsn: "user:passwd@https(:0)/?interpolateParams=false&token=token&skipVerify=true", user: "user", passwd: "passwd", net: "https", token: "token", skipVerify: true}, } for i, tc := range tcs { name := fmt.Sprintf("%d - %s", i, tc.dsn) @@ -45,7 +47,9 @@ func TestParseDsn(t *testing.T) { cfg.passwd != tc.passwd || cfg.net != tc.net || cfg.addr != tc.addr || - cfg.port != tc.port { + cfg.port != tc.port || + cfg.token != tc.token || + cfg.skipVerify != tc.skipVerify { t.Fatal(cfg) } }) diff --git a/taosSql/rows.go b/taosSql/rows.go index aaf684d..54b61a8 100644 --- a/taosSql/rows.go +++ b/taosSql/rows.go @@ -19,7 +19,6 @@ type rows struct { block unsafe.Pointer blockOffset int blockSize int - lengthList []int result unsafe.Pointer precision int isStmt bool @@ -107,7 +106,6 @@ func (rs *rows) taosFetchBlock() error { } rs.blockSize = result.N rs.block = wrapper.TaosGetRawBlock(result.Res) - rs.lengthList = wrapper.FetchLengths(rs.result, len(rs.rowsHeader.ColLength)) rs.blockOffset = 0 return nil } diff --git a/taosSql/statement.go b/taosSql/statement.go index e103a0e..9513e37 100644 --- a/taosSql/statement.go +++ b/taosSql/statement.go @@ -138,11 +138,11 @@ func (stmt *Stmt) CheckNamedValue(v *driver.NamedValue) error { case reflect.Bool: v.Value = types.TaosBool(rv.Bool()) case reflect.Float32, reflect.Float64: - v.Value = types.TaosBool(rv.Float() == 1) + v.Value = types.TaosBool(rv.Float() > 0) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - v.Value = types.TaosBool(rv.Int() == 1) + v.Value = types.TaosBool(rv.Int() > 0) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - v.Value = types.TaosBool(rv.Uint() == 1) + v.Value = types.TaosBool(rv.Uint() > 0) case reflect.String: vv, err := strconv.ParseBool(rv.String()) if err != nil { diff --git a/taosWS/connection.go b/taosWS/connection.go index ac2e436..c465815 100644 --- a/taosWS/connection.go +++ b/taosWS/connection.go @@ -7,17 +7,16 @@ import ( "encoding/json" "errors" "fmt" - "github.com/zeromicro/go-zero/core/logx" - "github.com/zeromicro/go-zero/core/syncx" - "github.com/zeromicro/go-zero/core/timex" "net/url" "strings" + "sync" "sync/atomic" "time" "github.com/gorilla/websocket" jsoniter "github.com/json-iterator/go" "github.com/taosdata/driver-go/v3/common" + stmtCommon "github.com/taosdata/driver-go/v3/common/stmt" taosErrors "github.com/taosdata/driver-go/v3/errors" ) @@ -29,16 +28,20 @@ const ( WSFetch = "fetch" WSFetchBlock = "fetch_block" WSFreeResult = "free_result" -) - -const defaultSlowThreshold = time.Millisecond * 500 -var slowThreshold = syncx.ForAtomicDuration(defaultSlowThreshold) + STMTInit = "init" + STMTPrepare = "prepare" + STMTAddBatch = "add_batch" + STMTExec = "exec" + STMTClose = "close" + STMTGetColFields = "get_col_fields" + STMTUseResult = "use_result" +) -// SetSlowThreshold sets the slow threshold. -func SetSlowThreshold(threshold time.Duration) { - slowThreshold.Set(threshold) -} +const ( + BinaryQueryMessage uint64 = 6 + FetchRawBlockMessage uint64 = 7 +) var ( NotQueryError = errors.New("sql is an update statement not a query statement") @@ -49,31 +52,44 @@ type taosConn struct { buf *bytes.Buffer client *websocket.Conn requestID uint64 + writeLock sync.Mutex readTimeout time.Duration writeTimeout time.Duration cfg *config + messageChan chan *message + messageError error endpoint string + closed uint32 + closeCh chan struct{} +} + +type message struct { + mt int + message []byte + err error } func (tc *taosConn) generateReqID() uint64 { return atomic.AddUint64(&tc.requestID, 1) } -func newTaosConn(ctx context.Context, cfg *config) (*taosConn, error) { +func newTaosConn(cfg *config) (*taosConn, error) { endpointUrl := &url.URL{ Scheme: cfg.net, Host: fmt.Sprintf("%s:%d", cfg.addr, cfg.port), - Path: "/rest/ws", + Path: "/ws", } if cfg.token != "" { endpointUrl.RawQuery = fmt.Sprintf("token=%s", cfg.token) } endpoint := endpointUrl.String() - ws, _, err := common.DefaultDialer.Dial(endpoint, nil) + dialer := common.DefaultDialer + dialer.EnableCompression = cfg.enableCompression + ws, _, err := dialer.Dial(endpoint, nil) if err != nil { return nil, err } - ws.SetReadLimit(common.BufferSize4M) + ws.EnableWriteCompression(cfg.enableCompression) ws.SetReadDeadline(time.Now().Add(common.DefaultPongWait)) ws.SetPongHandler(func(string) error { ws.SetReadDeadline(time.Now().Add(common.DefaultPongWait)) @@ -87,22 +103,59 @@ func newTaosConn(ctx context.Context, cfg *config) (*taosConn, error) { writeTimeout: cfg.writeTimeout, cfg: cfg, endpoint: endpoint, + closeCh: make(chan struct{}), + messageChan: make(chan *message, 10), } + go tc.ping() + go tc.read() err = tc.connect() if err != nil { - logx.WithContext(ctx).Errorf("websocket 连接失败,err:%v", err) tc.Close() - return nil, err } return tc, nil } +func (tc *taosConn) ping() { + ticker := time.NewTicker(common.DefaultPingPeriod) + defer ticker.Stop() + for { + select { + case <-tc.closeCh: + return + case <-ticker.C: + tc.writePing() + } + } +} + +func (tc *taosConn) read() { + for { + mt, msg, err := tc.client.ReadMessage() + tc.messageChan <- &message{ + mt: mt, + message: msg, + err: err, + } + if err != nil { + tc.messageError = NewBadConnError(err) + break + } + if tc.isClosed() { + break + } + } +} + func (tc *taosConn) Begin() (driver.Tx, error) { return nil, &taosErrors.TaosError{Code: 0xffff, ErrStr: "websocket does not support transaction"} } func (tc *taosConn) Close() (err error) { + if !tc.isClosed() { + atomic.StoreUint32(&tc.closed, 1) + close(tc.closeCh) + } if tc.client != nil { err = tc.client.Close() } @@ -112,51 +165,138 @@ func (tc *taosConn) Close() (err error) { return err } +func (tc *taosConn) isClosed() bool { + return atomic.LoadUint32(&tc.closed) != 0 +} + func (tc *taosConn) Prepare(query string) (driver.Stmt, error) { - return nil, &taosErrors.TaosError{Code: 0xffff, ErrStr: "websocket does not support stmt"} + if tc.isClosed() { + return nil, driver.ErrBadConn + } + stmtID, err := tc.stmtInit() + if err != nil { + return nil, err + } + isInsert, err := tc.stmtPrepare(stmtID, query) + if err != nil { + tc.stmtClose(stmtID) + return nil, err + } + stmt := &Stmt{ + conn: tc, + stmtID: stmtID, + isInsert: isInsert, + pSql: query, + } + return stmt, nil } -func (tc *taosConn) Exec(query string, args []driver.Value) (driver.Result, error) { - return tc.execCtx(context.Background(), query, common.ValueArgsToNamedValueArgs(args)) +func (tc *taosConn) stmtInit() (uint64, error) { + reqID := tc.generateReqID() + req := &StmtInitReq{ + ReqID: reqID, + } + reqArgs, err := json.Marshal(req) + if err != nil { + return 0, err + } + action := &WSAction{ + Action: STMTInit, + Args: reqArgs, + } + tc.buf.Reset() + err = jsonI.NewEncoder(tc.buf).Encode(action) + if err != nil { + return 0, err + } + err = tc.writeText(tc.buf.Bytes()) + if err != nil { + return 0, err + } + var resp StmtInitResp + err = tc.readTo(&resp) + if err != nil { + return 0, err + } + if resp.Code != 0 { + return 0, taosErrors.NewError(resp.Code, resp.Message) + } + return resp.StmtID, nil } -func (tc *taosConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (result driver.Result, err error) { - return tc.execCtx(ctx, query, args) +func (tc *taosConn) stmtPrepare(stmtID uint64, sql string) (bool, error) { + reqID := tc.generateReqID() + req := &StmtPrepareRequest{ + ReqID: reqID, + StmtID: stmtID, + SQL: sql, + } + reqArgs, err := json.Marshal(req) + if err != nil { + return false, err + } + action := &WSAction{ + Action: STMTPrepare, + Args: reqArgs, + } + tc.buf.Reset() + err = jsonI.NewEncoder(tc.buf).Encode(action) + if err != nil { + return false, err + } + err = tc.writeText(tc.buf.Bytes()) + if err != nil { + return false, err + } + var resp StmtPrepareResponse + err = tc.readTo(&resp) + if err != nil { + return false, err + } + if resp.Code != 0 { + return false, taosErrors.NewError(resp.Code, resp.Message) + } + return resp.IsInsert, nil } -func (tc *taosConn) execCtx(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { - if len(args) != 0 { - if !tc.cfg.interpolateParams { - return nil, driver.ErrSkip - } - // try to interpolate the parameters to save extra round trips for preparing and closing a statement - prepared, err := common.InterpolateParams(query, args) - if err != nil { - return nil, err - } - query = prepared +func (tc *taosConn) stmtClose(stmtID uint64) error { + reqID := tc.generateReqID() + req := &StmtCloseRequest{ + ReqID: reqID, + StmtID: stmtID, + } + reqArgs, err := json.Marshal(req) + if err != nil { + return err + } + action := &WSAction{ + Action: STMTClose, + Args: reqArgs, } + tc.buf.Reset() + err = jsonI.NewEncoder(tc.buf).Encode(action) + if err != nil { + return err + } + err = tc.writeText(tc.buf.Bytes()) + if err != nil { + return err + } + return nil +} +func (tc *taosConn) stmtGetColFields(stmtID uint64) ([]*stmtCommon.StmtField, error) { reqID := tc.generateReqID() - req := &WSQueryReq{ - ReqID: reqID, - SQL: query, - } - startTime := timex.Now() - duration := timex.Since(startTime) - defer func() { - if duration > slowThreshold.Load() { - logx.WithContext(ctx).WithDuration(duration).Slowf("[SQL] taosWsQuery reqID:%v slowcall query: %s", reqID, query) - } else { - logx.WithContext(ctx).WithDuration(duration).Infof("[SQL] taosWsQuery reqID:%v query: %s", reqID, query) - } - }() + req := &StmtGetColFieldsRequest{ + ReqID: reqID, + StmtID: stmtID, + } reqArgs, err := json.Marshal(req) if err != nil { return nil, err } action := &WSAction{ - Action: WSQuery, + Action: STMTGetColFields, Args: reqArgs, } tc.buf.Reset() @@ -168,7 +308,7 @@ func (tc *taosConn) execCtx(ctx context.Context, query string, args []driver.Nam if err != nil { return nil, err } - var resp WSQueryResp + var resp StmtGetColFieldsResponse err = tc.readTo(&resp) if err != nil { return nil, err @@ -176,40 +316,134 @@ func (tc *taosConn) execCtx(ctx context.Context, query string, args []driver.Nam if resp.Code != 0 { return nil, taosErrors.NewError(resp.Code, resp.Message) } - return driver.RowsAffected(resp.AffectedRows), nil + return resp.Fields, nil } -func (tc *taosConn) Query(query string, args []driver.Value) (driver.Rows, error) { - return tc.QueryContext(context.Background(), query, common.ValueArgsToNamedValueArgs(args)) +func (tc *taosConn) stmtBindParam(stmtID uint64, block []byte) error { + reqID := tc.generateReqID() + tc.buf.Reset() + WriteUint64(tc.buf, reqID) + WriteUint64(tc.buf, stmtID) + WriteUint64(tc.buf, BindMessage) + tc.buf.Write(block) + err := tc.writeBinary(tc.buf.Bytes()) + if err != nil { + return err + } + var resp StmtBindResponse + err = tc.readTo(&resp) + if err != nil { + return err + } + if resp.Code != 0 { + return taosErrors.NewError(resp.Code, resp.Message) + } + return nil } -func (tc *taosConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (rows driver.Rows, err error) { - return tc.queryCtx(ctx, query, args) +func WriteUint64(buffer *bytes.Buffer, v uint64) { + buffer.WriteByte(byte(v)) + buffer.WriteByte(byte(v >> 8)) + buffer.WriteByte(byte(v >> 16)) + buffer.WriteByte(byte(v >> 24)) + buffer.WriteByte(byte(v >> 32)) + buffer.WriteByte(byte(v >> 40)) + buffer.WriteByte(byte(v >> 48)) + buffer.WriteByte(byte(v >> 56)) } -func (tc *taosConn) queryCtx(_ context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { - if len(args) != 0 { - if !tc.cfg.interpolateParams { - return nil, driver.ErrSkip - } - // try client-side prepare to reduce round trip - prepared, err := common.InterpolateParams(query, args) - if err != nil { - return nil, err - } - query = prepared +func WriteUint32(buffer *bytes.Buffer, v uint32) { + buffer.WriteByte(byte(v)) + buffer.WriteByte(byte(v >> 8)) + buffer.WriteByte(byte(v >> 16)) + buffer.WriteByte(byte(v >> 24)) +} + +func WriteUint16(buffer *bytes.Buffer, v uint16) { + buffer.WriteByte(byte(v)) + buffer.WriteByte(byte(v >> 8)) +} + +func (tc *taosConn) stmtAddBatch(stmtID uint64) error { + reqID := tc.generateReqID() + req := &StmtAddBatchRequest{ + ReqID: reqID, + StmtID: stmtID, + } + reqArgs, err := json.Marshal(req) + if err != nil { + return err + } + action := &WSAction{ + Action: STMTAddBatch, + Args: reqArgs, + } + tc.buf.Reset() + err = jsonI.NewEncoder(tc.buf).Encode(action) + if err != nil { + return err + } + err = tc.writeText(tc.buf.Bytes()) + if err != nil { + return err } + var resp StmtAddBatchResponse + err = tc.readTo(&resp) + if err != nil { + return err + } + if resp.Code != 0 { + return taosErrors.NewError(resp.Code, resp.Message) + } + return nil +} + +func (tc *taosConn) stmtExec(stmtID uint64) (int, error) { reqID := tc.generateReqID() - req := &WSQueryReq{ - ReqID: reqID, - SQL: query, + req := &StmtExecRequest{ + ReqID: reqID, + StmtID: stmtID, + } + reqArgs, err := json.Marshal(req) + if err != nil { + return 0, err + } + action := &WSAction{ + Action: STMTExec, + Args: reqArgs, + } + tc.buf.Reset() + err = jsonI.NewEncoder(tc.buf).Encode(action) + if err != nil { + return 0, err + } + err = tc.writeText(tc.buf.Bytes()) + if err != nil { + return 0, err + } + var resp StmtExecResponse + err = tc.readTo(&resp) + if err != nil { + return 0, err + } + if resp.Code != 0 { + return 0, taosErrors.NewError(resp.Code, resp.Message) + } + return resp.Affected, nil +} + +func (tc *taosConn) stmtUseResult(stmtID uint64) (*rows, error) { + reqID := tc.generateReqID() + req := &StmtUseResultRequest{ + ReqID: reqID, + StmtID: stmtID, } reqArgs, err := json.Marshal(req) if err != nil { return nil, err } action := &WSAction{ - Action: WSQuery, + Action: STMTUseResult, Args: reqArgs, } tc.buf.Reset() @@ -221,7 +455,7 @@ func (tc *taosConn) queryCtx(_ context.Context, query string, args []driver.Name if err != nil { return nil, err } - var resp WSQueryResp + var resp StmtUseResultResponse err = tc.readTo(&resp) if err != nil { return nil, err @@ -229,6 +463,54 @@ func (tc *taosConn) queryCtx(_ context.Context, query string, args []driver.Name if resp.Code != 0 { return nil, taosErrors.NewError(resp.Code, resp.Message) } + rs := &rows{ + buf: &bytes.Buffer{}, + conn: tc, + resultID: resp.ResultID, + fieldsCount: resp.FieldsCount, + fieldsNames: resp.FieldsNames, + fieldsTypes: resp.FieldsTypes, + fieldsLengths: resp.FieldsLengths, + precision: resp.Precision, + isStmt: true, + } + return rs, nil +} +func (tc *taosConn) Exec(query string, args []driver.Value) (driver.Result, error) { + return tc.execCtx(context.Background(), query, common.ValueArgsToNamedValueArgs(args)) +} + +func (tc *taosConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (result driver.Result, err error) { + return tc.execCtx(ctx, query, args) +} + +func (tc *taosConn) execCtx(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + resp, err := tc.doQuery(ctx, query, args) + if err != nil { + return nil, err + } + if resp.Code != 0 { + return nil, taosErrors.NewError(resp.Code, resp.Message) + } + return driver.RowsAffected(resp.AffectedRows), nil +} + +func (tc *taosConn) Query(query string, args []driver.Value) (driver.Rows, error) { + return tc.QueryContext(context.Background(), query, common.ValueArgsToNamedValueArgs(args)) +} + +func (tc *taosConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (rows driver.Rows, err error) { + return tc.queryCtx(ctx, query, args) +} + +func (tc *taosConn) queryCtx(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + resp, err := tc.doQuery(ctx, query, args) + if err != nil { + return nil, err + } + if resp.Code != 0 { + return nil, taosErrors.NewError(resp.Code, resp.Message) + } if resp.IsUpdate { return nil, NotQueryError } @@ -245,8 +527,47 @@ func (tc *taosConn) queryCtx(_ context.Context, query string, args []driver.Name return rs, err } +func (tc *taosConn) doQuery(_ context.Context, query string, args []driver.NamedValue) (*WSQueryResp, error) { + if tc.isClosed() { + return nil, driver.ErrBadConn + } + if len(args) != 0 { + if !tc.cfg.interpolateParams { + return nil, driver.ErrSkip + } + // try client-side prepare to reduce round trip + prepared, err := common.InterpolateParams(query, args) + if err != nil { + return nil, err + } + query = prepared + } + reqID := tc.generateReqID() + tc.buf.Reset() + + WriteUint64(tc.buf, reqID) // req id + WriteUint64(tc.buf, 0) // message id + WriteUint64(tc.buf, BinaryQueryMessage) + WriteUint16(tc.buf, 1) // version + WriteUint32(tc.buf, uint32(len(query))) // sql length + tc.buf.WriteString(query) + err := tc.writeBinary(tc.buf.Bytes()) + if err != nil { + return nil, err + } + var resp WSQueryResp + err = tc.readTo(&resp) + if err != nil { + return nil, err + } + return &resp, nil +} + func (tc *taosConn) Ping(ctx context.Context) (err error) { - return nil + if tc.isClosed() { + return driver.ErrBadConn + } + return tc.writePing() } func (tc *taosConn) connect() error { @@ -285,8 +606,28 @@ func (tc *taosConn) connect() error { } func (tc *taosConn) writeText(data []byte) error { + return tc.write(websocket.TextMessage, data) +} + +func (tc *taosConn) writeBinary(data []byte) error { + return tc.write(websocket.BinaryMessage, data) +} + +func (tc *taosConn) writePing() error { + return tc.write(websocket.PingMessage, nil) +} + +func (tc *taosConn) write(messageType int, data []byte) error { + tc.writeLock.Lock() + defer tc.writeLock.Unlock() + if tc.isClosed() { + return driver.ErrBadConn + } + if tc.messageError != nil { + return tc.messageError + } tc.client.SetWriteDeadline(time.Now().Add(tc.writeTimeout)) - err := tc.client.WriteMessage(websocket.TextMessage, data) + err := tc.client.WriteMessage(messageType, data) if err != nil { return NewBadConnErrorWithCtx(err, string(data)) } @@ -294,63 +635,50 @@ func (tc *taosConn) writeText(data []byte) error { } func (tc *taosConn) readTo(to interface{}) error { - var outErr error - done := make(chan struct{}) - go func() { - defer func() { - close(done) - }() - mt, respBytes, err := tc.client.ReadMessage() - if err != nil { - outErr = NewBadConnError(err) - return - } - if mt != websocket.TextMessage { - outErr = NewBadConnErrorWithCtx(fmt.Errorf("readTo: got wrong message type %d", mt), formatBytes(respBytes)) - return - } - err = jsonI.Unmarshal(respBytes, to) - if err != nil { - outErr = NewBadConnErrorWithCtx(err, string(respBytes)) - return - } - }() - ctx, cancel := context.WithTimeout(context.Background(), tc.readTimeout) - defer cancel() - select { - case <-done: - return outErr - case <-ctx.Done(): - return NewBadConnError(ReadTimeoutError) + mt, respBytes, err := tc.readResponse() + if err != nil { + return err } + if mt != websocket.TextMessage { + return NewBadConnErrorWithCtx(fmt.Errorf("readTo: got wrong message type %d", mt), formatBytes(respBytes)) + } + err = jsonI.Unmarshal(respBytes, to) + if err != nil { + return NewBadConnErrorWithCtx(err, string(respBytes)) + } + return nil } func (tc *taosConn) readBytes() ([]byte, error) { - var respBytes []byte - var outErr error - done := make(chan struct{}) - go func() { - defer func() { - close(done) - }() - mt, message, err := tc.client.ReadMessage() - if err != nil { - outErr = NewBadConnError(err) - return - } - if mt != websocket.BinaryMessage { - outErr = NewBadConnErrorWithCtx(fmt.Errorf("readBytes: got wrong message type %d", mt), string(respBytes)) - return - } - respBytes = message - }() + mt, respBytes, err := tc.readResponse() + if err != nil { + return nil, err + } + if mt != websocket.BinaryMessage { + return nil, NewBadConnErrorWithCtx(fmt.Errorf("readBytes: got wrong message type %d", mt), string(respBytes)) + } + return respBytes, err +} + +func (tc *taosConn) readResponse() (int, []byte, error) { + if tc.isClosed() { + return 0, nil, driver.ErrBadConn + } + if tc.messageError != nil { + return 0, nil, tc.messageError + } ctx, cancel := context.WithTimeout(context.Background(), tc.readTimeout) defer cancel() select { - case <-done: - return respBytes, outErr + case <-tc.closeCh: + return 0, nil, driver.ErrBadConn + case msg := <-tc.messageChan: + if msg.err != nil { + return 0, nil, NewBadConnError(msg.err) + } + return msg.mt, msg.message, nil case <-ctx.Done(): - return nil, NewBadConnError(ReadTimeoutError) + return 0, nil, NewBadConnError(ReadTimeoutError) } } diff --git a/taosWS/connection_test.go b/taosWS/connection_test.go index 6aee030..e93d710 100644 --- a/taosWS/connection_test.go +++ b/taosWS/connection_test.go @@ -46,3 +46,29 @@ func Test_formatBytes(t *testing.T) { }) } } + +func TestBadConnection(t *testing.T) { + defer func() { + if r := recover(); r != nil { + // bad connection should not panic + t.Fatalf("panic: %v", r) + } + }() + + cfg, err := parseDSN(dataSourceName) + if err != nil { + t.Fatalf("parseDSN error: %v", err) + } + conn, err := newTaosConn(cfg) + if err != nil { + t.Fatalf("newTaosConn error: %v", err) + } + + // to test bad connection, we manually close the connection + conn.Close() + + _, err = conn.Query("select 1", nil) + if err == nil { + t.Fatalf("query should fail") + } +} diff --git a/taosWS/connector_test.go b/taosWS/connector_test.go index a40a251..d971a1d 100644 --- a/taosWS/connector_test.go +++ b/taosWS/connector_test.go @@ -4,6 +4,8 @@ import ( "database/sql" "fmt" "math/rand" + "reflect" + "strings" "testing" "time" @@ -11,11 +13,78 @@ import ( "github.com/taosdata/driver-go/v3/types" ) +func generateCreateTableSql(db string, withJson bool) string { + createSql := fmt.Sprintf("create table if not exists %s.alltype(ts timestamp,"+ + "c1 bool,"+ + "c2 tinyint,"+ + "c3 smallint,"+ + "c4 int,"+ + "c5 bigint,"+ + "c6 tinyint unsigned,"+ + "c7 smallint unsigned,"+ + "c8 int unsigned,"+ + "c9 bigint unsigned,"+ + "c10 float,"+ + "c11 double,"+ + "c12 binary(20),"+ + "c13 nchar(20),"+ + "c14 varbinary(100),"+ + "c15 geometry(100)"+ + ")", + db) + if withJson { + createSql += " tags(t json)" + } + return createSql +} + +func generateValues() (value []interface{}, scanValue []interface{}, insertSql string) { + rand.Seed(time.Now().UnixNano()) + v1 := true + v2 := int8(rand.Int()) + v3 := int16(rand.Int()) + v4 := rand.Int31() + v5 := int64(rand.Int31()) + v6 := uint8(rand.Uint32()) + v7 := uint16(rand.Uint32()) + v8 := rand.Uint32() + v9 := uint64(rand.Uint32()) + v10 := rand.Float32() + v11 := rand.Float64() + v12 := "test_binary" + v13 := "test_nchar" + v14 := []byte("test_varbinary") + v15 := []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40} + ts := time.Now().Round(time.Millisecond) + var ( + cts time.Time + c1 bool + c2 int8 + c3 int16 + c4 int32 + c5 int64 + c6 uint8 + c7 uint16 + c8 uint32 + c9 uint64 + c10 float32 + c11 float64 + c12 string + c13 string + c14 []byte + c15 []byte + ) + return []interface{}{ + ts, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, + }, []interface{}{cts, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15}, + fmt.Sprintf(`values('%s',%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,'test_binary','test_nchar','test_varbinary','point(100 100)')`, ts.Format(time.RFC3339Nano), v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11) +} + // @author: xftan // @date: 2023/10/13 11:22 // @description: test all type query func TestAllTypeQuery(t *testing.T) { - rand.Seed(time.Now().UnixNano()) + database := "ws_test" db, err := sql.Open("taosWS", dataSourceName) if err != nil { t.Fatal(err) @@ -26,57 +95,25 @@ func TestAllTypeQuery(t *testing.T) { t.Fatal(err) } defer func() { - _, err = db.Exec("drop database if exists ws_test") + _, err = db.Exec(fmt.Sprintf("drop database if exists %s", database)) if err != nil { t.Fatal(err) } }() - _, err = db.Exec("create database if not exists ws_test") + _, err = db.Exec(fmt.Sprintf("create database if not exists %s", database)) if err != nil { t.Fatal(err) } - var ( - v1 = true - v2 = int8(rand.Int()) - v3 = int16(rand.Int()) - v4 = rand.Int31() - v5 = int64(rand.Int31()) - v6 = uint8(rand.Uint32()) - v7 = uint16(rand.Uint32()) - v8 = rand.Uint32() - v9 = uint64(rand.Uint32()) - v10 = rand.Float32() - v11 = rand.Float64() - v12 = "test_binary" - v13 = "test_nchar" - ) - - _, err = db.Exec("create table if not exists ws_test.alltype(ts timestamp," + - "c1 bool," + - "c2 tinyint," + - "c3 smallint," + - "c4 int," + - "c5 bigint," + - "c6 tinyint unsigned," + - "c7 smallint unsigned," + - "c8 int unsigned," + - "c9 bigint unsigned," + - "c10 float," + - "c11 double," + - "c12 binary(20)," + - "c13 nchar(20)" + - ")" + - "tags(t json)", - ) + _, err = db.Exec(generateCreateTableSql(database, true)) if err != nil { t.Fatal(err) } - now := time.Now().Round(time.Millisecond) - _, err = db.Exec(fmt.Sprintf(`insert into ws_test.t1 using ws_test.alltype tags('{"a":"b"}') values('%s',%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,'test_binary','test_nchar')`, now.Format(time.RFC3339Nano), v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11)) + colValues, scanValues, insertSql := generateValues() + _, err = db.Exec(fmt.Sprintf(`insert into %s.t1 using %s.alltype tags('{"a":"b"}') %s`, database, database, insertSql)) if err != nil { t.Fatal(err) } - rows, err := db.Query(fmt.Sprintf("select * from ws_test.alltype where ts = '%s'", now.Format(time.RFC3339Nano))) + rows, err := db.Query(fmt.Sprintf("select * from %s.alltype where ts = '%s'", database, colValues[0].(time.Time).Format(time.RFC3339Nano))) assert.NoError(t, err) columns, err := rows.Columns() assert.NoError(t, err) @@ -84,71 +121,27 @@ func TestAllTypeQuery(t *testing.T) { cTypes, err := rows.ColumnTypes() assert.NoError(t, err) t.Log(cTypes) + var tt types.RawMessage + dest := make([]interface{}, len(scanValues)+1) + for i := range scanValues { + dest[i] = reflect.ValueOf(&scanValues[i]).Interface() + } + dest[len(scanValues)] = &tt for rows.Next() { - var ( - ts time.Time - c1 bool - c2 int8 - c3 int16 - c4 int32 - c5 int64 - c6 uint8 - c7 uint16 - c8 uint32 - c9 uint64 - c10 float32 - c11 float64 - c12 string - c13 string - tt types.RawMessage - ) - err := rows.Scan( - &ts, - &c1, - &c2, - &c3, - &c4, - &c5, - &c6, - &c7, - &c8, - &c9, - &c10, - &c11, - &c12, - &c13, - &tt, - ) - assert.Equal(t, now.UTC(), ts.UTC()) - assert.Equal(t, v1, c1) - assert.Equal(t, v2, c2) - assert.Equal(t, v3, c3) - assert.Equal(t, v4, c4) - assert.Equal(t, v5, c5) - assert.Equal(t, v6, c6) - assert.Equal(t, v7, c7) - assert.Equal(t, v8, c8) - assert.Equal(t, v9, c9) - assert.Equal(t, v10, c10) - assert.Equal(t, v11, c11) - assert.Equal(t, v12, c12) - assert.Equal(t, v13, c13) - assert.Equal(t, types.RawMessage(`{"a":"b"}`), tt) - if err != nil { - t.Fatal(err) - } - if ts.IsZero() { - t.Fatal(ts) - } - + err := rows.Scan(dest...) + assert.NoError(t, err) } + for i, v := range colValues { + assert.Equal(t, v, scanValues[i]) + } + assert.Equal(t, types.RawMessage(`{"a":"b"}`), tt) } // @author: xftan // @date: 2023/10/13 11:22 // @description: test null value func TestAllTypeQueryNull(t *testing.T) { - rand.Seed(time.Now().UnixNano()) + database := "ws_test_null" db, err := sql.Open("taosWS", dataSourceName) if err != nil { t.Fatal(err) @@ -159,42 +152,29 @@ func TestAllTypeQueryNull(t *testing.T) { t.Fatal(err) } defer func() { - _, err = db.Exec("drop database if exists ws_test_null") + _, err = db.Exec(fmt.Sprintf("drop database if exists %s", database)) if err != nil { t.Fatal(err) } }() - _, err = db.Exec("create database if not exists ws_test_null") + _, err = db.Exec(fmt.Sprintf("create database if not exists %s", database)) if err != nil { t.Fatal(err) } - - _, err = db.Exec("create table if not exists ws_test_null.alltype(ts timestamp," + - "c1 bool," + - "c2 tinyint," + - "c3 smallint," + - "c4 int," + - "c5 bigint," + - "c6 tinyint unsigned," + - "c7 smallint unsigned," + - "c8 int unsigned," + - "c9 bigint unsigned," + - "c10 float," + - "c11 double," + - "c12 binary(20)," + - "c13 nchar(20)" + - ")" + - "tags(t json)", - ) + _, err = db.Exec(generateCreateTableSql(database, true)) if err != nil { t.Fatal(err) } - now := time.Now().Round(time.Millisecond) - _, err = db.Exec(fmt.Sprintf(`insert into ws_test_null.t1 using ws_test_null.alltype tags('null') values('%s',null,null,null,null,null,null,null,null,null,null,null,null,null)`, now.Format(time.RFC3339Nano))) + colValues, _, _ := generateValues() + builder := &strings.Builder{} + for i := 1; i < len(colValues); i++ { + builder.WriteString(",null") + } + _, err = db.Exec(fmt.Sprintf(`insert into %s.t1 using %s.alltype tags('{"a":"b"}') values('%s'%s)`, database, database, colValues[0].(time.Time).Format(time.RFC3339Nano), builder.String())) if err != nil { t.Fatal(err) } - rows, err := db.Query(fmt.Sprintf("select * from ws_test_null.alltype where ts = '%s'", now.Format(time.RFC3339Nano))) + rows, err := db.Query(fmt.Sprintf("select * from %s.alltype where ts = '%s'", database, colValues[0].(time.Time).Format(time.RFC3339Nano))) assert.NoError(t, err) columns, err := rows.Columns() assert.NoError(t, err) @@ -202,72 +182,32 @@ func TestAllTypeQueryNull(t *testing.T) { cTypes, err := rows.ColumnTypes() assert.NoError(t, err) t.Log(cTypes) + values := make([]interface{}, len(cTypes)) + values[0] = new(time.Time) + for i := 1; i < len(colValues); i++ { + var v interface{} + values[i] = &v + } + var tt types.RawMessage + values[len(colValues)] = &tt for rows.Next() { - var ( - ts time.Time - c1 *bool - c2 *int8 - c3 *int16 - c4 *int32 - c5 *int64 - c6 *uint8 - c7 *uint16 - c8 *uint32 - c9 *uint64 - c10 *float32 - c11 *float64 - c12 *string - c13 *string - tt *string - ) - err := rows.Scan( - &ts, - &c1, - &c2, - &c3, - &c4, - &c5, - &c6, - &c7, - &c8, - &c9, - &c10, - &c11, - &c12, - &c13, - &tt, - ) - assert.Equal(t, now.UTC(), ts.UTC()) - assert.Nil(t, c1) - assert.Nil(t, c2) - assert.Nil(t, c3) - assert.Nil(t, c4) - assert.Nil(t, c5) - assert.Nil(t, c6) - assert.Nil(t, c7) - assert.Nil(t, c8) - assert.Nil(t, c9) - assert.Nil(t, c10) - assert.Nil(t, c11) - assert.Nil(t, c12) - assert.Nil(t, c13) - assert.Nil(t, tt) + err := rows.Scan(values...) if err != nil { - t.Fatal(err) } - if ts.IsZero() { - t.Fatal(ts) - } - } + assert.Equal(t, *values[0].(*time.Time), colValues[0].(time.Time)) + for i := 1; i < len(values)-1; i++ { + assert.Nil(t, *values[i].(*interface{})) + } + assert.Equal(t, types.RawMessage(`{"a":"b"}`), *(values[len(values)-1]).(*types.RawMessage)) } // @author: xftan // @date: 2023/10/13 11:24 // @description: test compression func TestAllTypeQueryCompression(t *testing.T) { - rand.Seed(time.Now().UnixNano()) + database := "ws_test_compression" db, err := sql.Open("taosWS", dataSourceNameWithCompression) if err != nil { t.Fatal(err) @@ -278,57 +218,25 @@ func TestAllTypeQueryCompression(t *testing.T) { t.Fatal(err) } defer func() { - _, err = db.Exec("drop database if exists ws_test") + _, err = db.Exec(fmt.Sprintf("drop database if exists %s", database)) if err != nil { t.Fatal(err) } }() - _, err = db.Exec("create database if not exists ws_test") + _, err = db.Exec(fmt.Sprintf("create database if not exists %s", database)) if err != nil { t.Fatal(err) } - var ( - v1 = true - v2 = int8(rand.Int()) - v3 = int16(rand.Int()) - v4 = rand.Int31() - v5 = int64(rand.Int31()) - v6 = uint8(rand.Uint32()) - v7 = uint16(rand.Uint32()) - v8 = rand.Uint32() - v9 = uint64(rand.Uint32()) - v10 = rand.Float32() - v11 = rand.Float64() - v12 = "test_binary" - v13 = "test_nchar" - ) - - _, err = db.Exec("create table if not exists ws_test.alltype(ts timestamp," + - "c1 bool," + - "c2 tinyint," + - "c3 smallint," + - "c4 int," + - "c5 bigint," + - "c6 tinyint unsigned," + - "c7 smallint unsigned," + - "c8 int unsigned," + - "c9 bigint unsigned," + - "c10 float," + - "c11 double," + - "c12 binary(20)," + - "c13 nchar(20)" + - ")" + - "tags(t json)", - ) + _, err = db.Exec(generateCreateTableSql(database, true)) if err != nil { t.Fatal(err) } - now := time.Now().Round(time.Millisecond) - _, err = db.Exec(fmt.Sprintf(`insert into ws_test.t1 using ws_test.alltype tags('{"a":"b"}') values('%s',%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,'test_binary','test_nchar')`, now.Format(time.RFC3339Nano), v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11)) + colValues, scanValues, insertSql := generateValues() + _, err = db.Exec(fmt.Sprintf(`insert into %s.t1 using %s.alltype tags('{"a":"b"}') %s`, database, database, insertSql)) if err != nil { t.Fatal(err) } - rows, err := db.Query(fmt.Sprintf("select * from ws_test.alltype where ts = '%s'", now.Format(time.RFC3339Nano))) + rows, err := db.Query(fmt.Sprintf("select * from %s.alltype where ts = '%s'", database, colValues[0].(time.Time).Format(time.RFC3339Nano))) assert.NoError(t, err) columns, err := rows.Columns() assert.NoError(t, err) @@ -336,70 +244,27 @@ func TestAllTypeQueryCompression(t *testing.T) { cTypes, err := rows.ColumnTypes() assert.NoError(t, err) t.Log(cTypes) + var tt types.RawMessage + dest := make([]interface{}, len(scanValues)+1) + for i := range scanValues { + dest[i] = reflect.ValueOf(&scanValues[i]).Interface() + } + dest[len(scanValues)] = &tt for rows.Next() { - var ( - ts time.Time - c1 bool - c2 int8 - c3 int16 - c4 int32 - c5 int64 - c6 uint8 - c7 uint16 - c8 uint32 - c9 uint64 - c10 float32 - c11 float64 - c12 string - c13 string - tt types.RawMessage - ) - err := rows.Scan( - &ts, - &c1, - &c2, - &c3, - &c4, - &c5, - &c6, - &c7, - &c8, - &c9, - &c10, - &c11, - &c12, - &c13, - &tt, - ) - assert.Equal(t, now.UTC(), ts.UTC()) - assert.Equal(t, v1, c1) - assert.Equal(t, v2, c2) - assert.Equal(t, v3, c3) - assert.Equal(t, v4, c4) - assert.Equal(t, v5, c5) - assert.Equal(t, v6, c6) - assert.Equal(t, v7, c7) - assert.Equal(t, v8, c8) - assert.Equal(t, v9, c9) - assert.Equal(t, v10, c10) - assert.Equal(t, v11, c11) - assert.Equal(t, v12, c12) - assert.Equal(t, v13, c13) - assert.Equal(t, types.RawMessage(`{"a":"b"}`), tt) - if err != nil { - t.Fatal(err) - } - if ts.IsZero() { - t.Fatal(ts) - } + err := rows.Scan(dest...) + assert.NoError(t, err) + } + for i, v := range colValues { + assert.Equal(t, v, scanValues[i]) } + assert.Equal(t, types.RawMessage(`{"a":"b"}`), tt) } // @author: xftan // @date: 2023/10/13 11:24 // @description: test all type query without json func TestAllTypeQueryWithoutJson(t *testing.T) { - rand.Seed(time.Now().UnixNano()) + database := "ws_test_without_json" db, err := sql.Open("taosWS", dataSourceName) if err != nil { t.Fatal(err) @@ -410,56 +275,25 @@ func TestAllTypeQueryWithoutJson(t *testing.T) { t.Fatal(err) } defer func() { - _, err = db.Exec("drop database if exists ws_test_without_json") + _, err = db.Exec(fmt.Sprintf("drop database if exists %s", database)) if err != nil { t.Fatal(err) } }() - _, err = db.Exec("create database if not exists ws_test_without_json") + _, err = db.Exec(fmt.Sprintf("create database if not exists %s", database)) if err != nil { t.Fatal(err) } - var ( - v1 = false - v2 = int8(rand.Int()) - v3 = int16(rand.Int()) - v4 = rand.Int31() - v5 = int64(rand.Int31()) - v6 = uint8(rand.Uint32()) - v7 = uint16(rand.Uint32()) - v8 = rand.Uint32() - v9 = uint64(rand.Uint32()) - v10 = rand.Float32() - v11 = rand.Float64() - v12 = "test_binary" - v13 = "test_nchar" - ) - - _, err = db.Exec("create table if not exists ws_test_without_json.all_type(ts timestamp," + - "c1 bool," + - "c2 tinyint," + - "c3 smallint," + - "c4 int," + - "c5 bigint," + - "c6 tinyint unsigned," + - "c7 smallint unsigned," + - "c8 int unsigned," + - "c9 bigint unsigned," + - "c10 float," + - "c11 double," + - "c12 binary(20)," + - "c13 nchar(20)" + - ")", - ) + _, err = db.Exec(generateCreateTableSql(database, false)) if err != nil { t.Fatal(err) } - now := time.Now().Round(time.Millisecond) - _, err = db.Exec(fmt.Sprintf(`insert into ws_test_without_json.all_type values('%s',%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,'test_binary','test_nchar')`, now.Format(time.RFC3339Nano), v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11)) + colValues, scanValues, insertSql := generateValues() + _, err = db.Exec(fmt.Sprintf(`insert into %s.alltype %s`, database, insertSql)) if err != nil { t.Fatal(err) } - rows, err := db.Query(fmt.Sprintf("select * from ws_test_without_json.all_type where ts = '%s'", now.Format(time.RFC3339Nano))) + rows, err := db.Query(fmt.Sprintf("select * from %s.alltype where ts = '%s'", database, colValues[0].(time.Time).Format(time.RFC3339Nano))) assert.NoError(t, err) columns, err := rows.Columns() assert.NoError(t, err) @@ -467,60 +301,16 @@ func TestAllTypeQueryWithoutJson(t *testing.T) { cTypes, err := rows.ColumnTypes() assert.NoError(t, err) t.Log(cTypes) + dest := make([]interface{}, len(scanValues)) + for i := range scanValues { + dest[i] = reflect.ValueOf(&scanValues[i]).Interface() + } for rows.Next() { - var ( - ts time.Time - c1 bool - c2 int8 - c3 int16 - c4 int32 - c5 int64 - c6 uint8 - c7 uint16 - c8 uint32 - c9 uint64 - c10 float32 - c11 float64 - c12 string - c13 string - ) - err := rows.Scan( - &ts, - &c1, - &c2, - &c3, - &c4, - &c5, - &c6, - &c7, - &c8, - &c9, - &c10, - &c11, - &c12, - &c13, - ) - assert.Equal(t, now.UTC(), ts.UTC()) - assert.Equal(t, v1, c1) - assert.Equal(t, v2, c2) - assert.Equal(t, v3, c3) - assert.Equal(t, v4, c4) - assert.Equal(t, v5, c5) - assert.Equal(t, v6, c6) - assert.Equal(t, v7, c7) - assert.Equal(t, v8, c8) - assert.Equal(t, v9, c9) - assert.Equal(t, v10, c10) - assert.Equal(t, v11, c11) - assert.Equal(t, v12, c12) - assert.Equal(t, v13, c13) - if err != nil { - t.Fatal(err) - } - if ts.IsZero() { - t.Fatal(ts) - } - + err := rows.Scan(dest...) + assert.NoError(t, err) + } + for i, v := range colValues { + assert.Equal(t, v, scanValues[i]) } } @@ -528,7 +318,7 @@ func TestAllTypeQueryWithoutJson(t *testing.T) { // @date: 2023/10/13 11:24 // @description: test all type query with null without json func TestAllTypeQueryNullWithoutJson(t *testing.T) { - rand.Seed(time.Now().UnixNano()) + database := "ws_test_without_json_null" db, err := sql.Open("taosWS", dataSourceName) if err != nil { t.Fatal(err) @@ -539,41 +329,30 @@ func TestAllTypeQueryNullWithoutJson(t *testing.T) { t.Fatal(err) } defer func() { - _, err = db.Exec("drop database if exists ws_test_without_json_null") + _, err = db.Exec(fmt.Sprintf("drop database if exists %s", database)) if err != nil { t.Fatal(err) } }() - _, err = db.Exec("create database if not exists ws_test_without_json_null") + _, err = db.Exec(fmt.Sprintf("create database if not exists %s", database)) if err != nil { t.Fatal(err) } - - _, err = db.Exec("create table if not exists ws_test_without_json_null.all_type(ts timestamp," + - "c1 bool," + - "c2 tinyint," + - "c3 smallint," + - "c4 int," + - "c5 bigint," + - "c6 tinyint unsigned," + - "c7 smallint unsigned," + - "c8 int unsigned," + - "c9 bigint unsigned," + - "c10 float," + - "c11 double," + - "c12 binary(20)," + - "c13 nchar(20)" + - ")", - ) + _, err = db.Exec(generateCreateTableSql(database, false)) if err != nil { t.Fatal(err) } - now := time.Now().Round(time.Millisecond) - _, err = db.Exec(fmt.Sprintf(`insert into ws_test_without_json_null.all_type values('%s',null,null,null,null,null,null,null,null,null,null,null,null,null)`, now.Format(time.RFC3339Nano))) + colValues, _, _ := generateValues() + builder := &strings.Builder{} + for i := 1; i < len(colValues); i++ { + builder.WriteString(",null") + } + insertSql := fmt.Sprintf(`insert into %s.alltype values('%s'%s)`, database, colValues[0].(time.Time).Format(time.RFC3339Nano), builder.String()) + _, err = db.Exec(insertSql) if err != nil { t.Fatal(err) } - rows, err := db.Query(fmt.Sprintf("select * from ws_test_without_json_null.all_type where ts = '%s'", now.Format(time.RFC3339Nano))) + rows, err := db.Query(fmt.Sprintf("select * from %s.alltype where ts = '%s'", database, colValues[0].(time.Time).Format(time.RFC3339Nano))) assert.NoError(t, err) columns, err := rows.Columns() assert.NoError(t, err) @@ -581,61 +360,21 @@ func TestAllTypeQueryNullWithoutJson(t *testing.T) { cTypes, err := rows.ColumnTypes() assert.NoError(t, err) t.Log(cTypes) + values := make([]interface{}, len(cTypes)) + values[0] = new(time.Time) + for i := 1; i < len(colValues); i++ { + var v interface{} + values[i] = &v + } for rows.Next() { - var ( - ts time.Time - c1 *bool - c2 *int8 - c3 *int16 - c4 *int32 - c5 *int64 - c6 *uint8 - c7 *uint16 - c8 *uint32 - c9 *uint64 - c10 *float32 - c11 *float64 - c12 *string - c13 *string - ) - err := rows.Scan( - &ts, - &c1, - &c2, - &c3, - &c4, - &c5, - &c6, - &c7, - &c8, - &c9, - &c10, - &c11, - &c12, - &c13, - ) - assert.Equal(t, now.UTC(), ts.UTC()) - assert.Nil(t, c1) - assert.Nil(t, c2) - assert.Nil(t, c3) - assert.Nil(t, c4) - assert.Nil(t, c5) - assert.Nil(t, c6) - assert.Nil(t, c7) - assert.Nil(t, c8) - assert.Nil(t, c9) - assert.Nil(t, c10) - assert.Nil(t, c11) - assert.Nil(t, c12) - assert.Nil(t, c13) + err := rows.Scan(values...) if err != nil { - t.Fatal(err) } - if ts.IsZero() { - t.Fatal(ts) - } - + } + assert.Equal(t, *values[0].(*time.Time), colValues[0].(time.Time)) + for i := 1; i < len(values)-1; i++ { + assert.Nil(t, *values[i].(*interface{})) } } diff --git a/taosWS/driver_test.go b/taosWS/driver_test.go index 9e12d7a..56b70e8 100644 --- a/taosWS/driver_test.go +++ b/taosWS/driver_test.go @@ -34,7 +34,7 @@ var ( port = 6041 dbName = "test_taos_ws" dataSourceName = fmt.Sprintf("%s:%s@ws(%s:%d)/", user, password, host, port) - dataSourceNameWithCompression = fmt.Sprintf("%s:%s@ws(%s:%d)/?disableCompression=false", user, password, host, port) + dataSourceNameWithCompression = fmt.Sprintf("%s:%s@ws(%s:%d)/?enableCompression=true", user, password, host, port) ) type DBTest struct { diff --git a/taosWS/dsn.go b/taosWS/dsn.go index ca5ff22..aa8e4dd 100644 --- a/taosWS/dsn.go +++ b/taosWS/dsn.go @@ -29,6 +29,7 @@ type config struct { params map[string]string // Connection parameters interpolateParams bool // Interpolate placeholders into query string token string // cloud platform token + enableCompression bool // Enable write compression readTimeout time.Duration // read message timeout writeTimeout time.Duration // write message timeout } @@ -143,6 +144,11 @@ func parseDSNParams(cfg *config, params string) (err error) { } case "token": cfg.token = value + case "enableCompression": + cfg.enableCompression, err = strconv.ParseBool(value) + if err != nil { + return &errors.TaosError{Code: 0xffff, ErrStr: "invalid enableCompression value: " + value} + } case "readTimeout": cfg.readTimeout, err = time.ParseDuration(value) if err != nil { diff --git a/taosWS/dsn_test.go b/taosWS/dsn_test.go index edfd013..7d7c7d7 100644 --- a/taosWS/dsn_test.go +++ b/taosWS/dsn_test.go @@ -25,6 +25,15 @@ func TestParseDsn(t *testing.T) { {dsn: "user:passwd@wss(:0)/?interpolateParams=false&test=1", want: &config{user: "user", passwd: "passwd", net: "wss", params: map[string]string{"test": "1"}}}, {dsn: "user:passwd@wss(:0)/?interpolateParams=false&token=token", want: &config{user: "user", passwd: "passwd", net: "wss", token: "token"}}, {dsn: "user:passwd@wss(:0)/?writeTimeout=8s&readTimeout=10m", want: &config{user: "user", passwd: "passwd", net: "wss", readTimeout: 10 * time.Minute, writeTimeout: 8 * time.Second, interpolateParams: true}}, + {dsn: "user:passwd@wss(:0)/?writeTimeout=8s&readTimeout=10m&enableCompression=true", want: &config{ + user: "user", + passwd: "passwd", + net: "wss", + readTimeout: 10 * time.Minute, + writeTimeout: 8 * time.Second, + interpolateParams: true, + enableCompression: true, + }}, } for _, tc := range tests { t.Run(tc.dsn, func(t *testing.T) { diff --git a/taosWS/proto.go b/taosWS/proto.go index fd2fb39..2731eec 100644 --- a/taosWS/proto.go +++ b/taosWS/proto.go @@ -1,6 +1,10 @@ package taosWS -import "encoding/json" +import ( + "encoding/json" + + stmtCommon "github.com/taosdata/driver-go/v3/common/stmt" +) type WSConnectReq struct { ReqID uint64 `json:"req_id"` @@ -69,3 +73,122 @@ type WSAction struct { Action string `json:"action"` Args json.RawMessage `json:"args"` } + +type StmtPrepareRequest struct { + ReqID uint64 `json:"req_id"` + StmtID uint64 `json:"stmt_id"` + SQL string `json:"sql"` +} + +type StmtPrepareResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` + IsInsert bool `json:"is_insert"` +} + +type StmtInitReq struct { + ReqID uint64 `json:"req_id"` +} + +type StmtInitResp struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` +} +type StmtCloseRequest struct { + ReqID uint64 `json:"req_id"` + StmtID uint64 `json:"stmt_id"` +} + +type StmtCloseResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id,omitempty"` +} + +type StmtGetColFieldsRequest struct { + ReqID uint64 `json:"req_id"` + StmtID uint64 `json:"stmt_id"` +} + +type StmtGetColFieldsResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` + Fields []*stmtCommon.StmtField `json:"fields"` +} + +const ( + BindMessage = 2 +) + +type StmtBindResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` +} + +type StmtAddBatchRequest struct { + ReqID uint64 `json:"req_id"` + StmtID uint64 `json:"stmt_id"` +} + +type StmtAddBatchResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` +} + +type StmtExecRequest struct { + ReqID uint64 `json:"req_id"` + StmtID uint64 `json:"stmt_id"` +} + +type StmtExecResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` + Affected int `json:"affected"` +} + +type StmtUseResultRequest struct { + ReqID uint64 `json:"req_id"` + StmtID uint64 `json:"stmt_id"` +} + +type StmtUseResultResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` + ResultID uint64 `json:"result_id"` + FieldsCount int `json:"fields_count"` + FieldsNames []string `json:"fields_names"` + FieldsTypes []uint8 `json:"fields_types"` + FieldsLengths []int64 `json:"fields_lengths"` + Precision int `json:"precision"` +} diff --git a/taosWS/rows.go b/taosWS/rows.go index b462f4e..636f54c 100644 --- a/taosWS/rows.go +++ b/taosWS/rows.go @@ -3,14 +3,15 @@ package taosWS import ( "bytes" "database/sql/driver" + "encoding/binary" "encoding/json" + "fmt" "io" "reflect" "unsafe" "github.com/taosdata/driver-go/v3/common" "github.com/taosdata/driver-go/v3/common/parser" - "github.com/taosdata/driver-go/v3/common/pointer" taosErrors "github.com/taosdata/driver-go/v3/errors" ) @@ -27,6 +28,7 @@ type rows struct { fieldsTypes []uint8 fieldsLengths []int64 precision int + isStmt bool } func (rs *rows) Columns() []string { @@ -85,75 +87,53 @@ func (rs *rows) Next(dest []driver.Value) error { func (rs *rows) taosFetchBlock() error { reqID := rs.conn.generateReqID() - req := &WSFetchReq{ - ReqID: reqID, - ID: rs.resultID, - } - args, err := json.Marshal(req) - if err != nil { - return err - } - action := &WSAction{ - Action: WSFetch, - Args: args, - } rs.buf.Reset() - - err = jsonI.NewEncoder(rs.buf).Encode(action) + WriteUint64(rs.buf, reqID) // req id + WriteUint64(rs.buf, rs.resultID) // message id + WriteUint64(rs.buf, FetchRawBlockMessage) + WriteUint16(rs.buf, 1) // version + err := rs.conn.writeBinary(rs.buf.Bytes()) if err != nil { return err } - err = rs.conn.writeText(rs.buf.Bytes()) + respBytes, err := rs.conn.readBytes() if err != nil { return err } - var resp WSFetchResp - err = rs.conn.readTo(&resp) - if err != nil { - return err + if len(respBytes) < 51 { + return taosErrors.NewError(0xffff, "invalid fetch raw block response") } - if resp.Code != 0 { - return taosErrors.NewError(resp.Code, resp.Message) + version := binary.LittleEndian.Uint16(respBytes[16:]) + if version != 1 { + return taosErrors.NewError(0xffff, fmt.Sprintf("unsupported fetch raw block version: %d", version)) } - if resp.Completed { + code := binary.LittleEndian.Uint32(respBytes[34:]) + msgLen := int(binary.LittleEndian.Uint32(respBytes[38:])) + if len(respBytes) < 51+msgLen { + return taosErrors.NewError(0xffff, "invalid fetch raw block response") + } + errMsg := string(respBytes[42 : 42+msgLen]) + if code != 0 { + return taosErrors.NewError(int(code), errMsg) + } + completed := respBytes[50+msgLen] == 1 + if completed { rs.blockSize = 0 return nil } else { - rs.blockSize = resp.Rows - return rs.fetchBlock() - } -} - -func (rs *rows) fetchBlock() error { - reqID := rs.conn.generateReqID() - req := &WSFetchBlockReq{ - ReqID: reqID, - ID: rs.resultID, - } - args, err := json.Marshal(req) - if err != nil { - return err - } - action := &WSAction{ - Action: WSFetchBlock, - Args: args, - } - rs.buf.Reset() - err = jsonI.NewEncoder(rs.buf).Encode(action) - if err != nil { - return err - } - err = rs.conn.writeText(rs.buf.Bytes()) - if err != nil { - return err - } - respBytes, err := rs.conn.readBytes() - if err != nil { - return err + if len(respBytes) < 55+msgLen { + return taosErrors.NewError(0xffff, "invalid fetch raw block response") + } + blockLength := binary.LittleEndian.Uint32(respBytes[51+msgLen:]) + if len(respBytes) < 55+msgLen+int(blockLength) { + return taosErrors.NewError(0xffff, "invalid fetch raw block response") + } + rawBlock := respBytes[55+msgLen : 55+msgLen+int(blockLength)] + rs.block = rawBlock + rs.blockPtr = unsafe.Pointer(&rs.block[0]) + rs.blockSize = int(parser.RawBlockGetNumOfRows(rs.blockPtr)) + rs.blockOffset = 0 } - rs.block = respBytes - rs.blockPtr = pointer.AddUintptr(unsafe.Pointer(&rs.block[0]), 16) - rs.blockOffset = 0 return nil } @@ -177,5 +157,5 @@ func (rs *rows) freeResult() error { if err != nil { return err } - return nil + return tc.writeText(rs.buf.Bytes()) } diff --git a/taosWS/statement.go b/taosWS/statement.go new file mode 100644 index 0000000..11f1f7e --- /dev/null +++ b/taosWS/statement.go @@ -0,0 +1,520 @@ +package taosWS + +import ( + "bytes" + "database/sql/driver" + "errors" + "fmt" + "reflect" + "strconv" + "time" + + "github.com/taosdata/driver-go/v3/common" + "github.com/taosdata/driver-go/v3/common/param" + "github.com/taosdata/driver-go/v3/common/serializer" + stmtCommon "github.com/taosdata/driver-go/v3/common/stmt" + "github.com/taosdata/driver-go/v3/types" +) + +type Stmt struct { + stmtID uint64 + conn *taosConn + buffer bytes.Buffer + pSql string + isInsert bool + cols []*stmtCommon.StmtField + colTypes *param.ColumnType + queryColTypes []*types.ColumnType +} + +func (stmt *Stmt) Close() error { + if stmt.conn == nil || stmt.conn.isClosed() || stmt.conn.messageError != nil { + return driver.ErrBadConn + } + err := stmt.conn.stmtClose(stmt.stmtID) + stmt.buffer.Reset() + stmt.conn = nil + return err +} + +func (stmt *Stmt) NumInput() int { + if stmt.colTypes != nil { + return len(stmt.cols) + } + return -1 +} + +func (stmt *Stmt) Exec(args []driver.Value) (driver.Result, error) { + if stmt.conn.isClosed() { + return nil, driver.ErrBadConn + } + if len(args) != len(stmt.cols) { + return nil, fmt.Errorf("stmt exec error: wrong number of parameters") + } + block, err := serializer.SerializeRawBlock(param.NewParamsWithRowValue(args), stmt.colTypes) + if err != nil { + return nil, err + } + err = stmt.conn.stmtBindParam(stmt.stmtID, block) + if err != nil { + return nil, err + } + err = stmt.conn.stmtAddBatch(stmt.stmtID) + if err != nil { + return nil, err + } + affected, err := stmt.conn.stmtExec(stmt.stmtID) + if err != nil { + return nil, err + } + return driver.RowsAffected(affected), nil +} + +func (stmt *Stmt) Query(args []driver.Value) (driver.Rows, error) { + if stmt.conn.isClosed() { + return nil, driver.ErrBadConn + } + block, err := serializer.SerializeRawBlock(param.NewParamsWithRowValue(args), param.NewColumnTypeWithValue(stmt.queryColTypes)) + if err != nil { + return nil, err + } + err = stmt.conn.stmtBindParam(stmt.stmtID, block) + if err != nil { + return nil, err + } + err = stmt.conn.stmtAddBatch(stmt.stmtID) + if err != nil { + return nil, err + } + _, err = stmt.conn.stmtExec(stmt.stmtID) + if err != nil { + return nil, err + } + return stmt.conn.stmtUseResult(stmt.stmtID) +} + +func (stmt *Stmt) CheckNamedValue(v *driver.NamedValue) error { + if stmt.isInsert { + if stmt.cols == nil { + cols, err := stmt.conn.stmtGetColFields(stmt.stmtID) + if err != nil { + return err + } + colTypes := make([]*types.ColumnType, len(cols)) + for i, col := range cols { + t, err := col.GetType() + if err != nil { + return err + } + colTypes[i] = t + } + stmt.cols = cols + stmt.colTypes = param.NewColumnTypeWithValue(colTypes) + } + if v.Ordinal > len(stmt.cols) { + return nil + } + if v.Value == nil { + return nil + } + switch stmt.cols[v.Ordinal-1].FieldType { + case common.TSDB_DATA_TYPE_NULL: + v.Value = nil + case common.TSDB_DATA_TYPE_BOOL: + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Bool: + v.Value = types.TaosBool(rv.Bool()) + case reflect.Float32, reflect.Float64: + v.Value = types.TaosBool(rv.Float() > 0) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.Value = types.TaosBool(rv.Int() > 0) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v.Value = types.TaosBool(rv.Uint() > 0) + case reflect.String: + vv, err := strconv.ParseBool(rv.String()) + if err != nil { + return err + } + v.Value = types.TaosBool(vv) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to bool", v) + } + case common.TSDB_DATA_TYPE_TINYINT: + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Bool: + if rv.Bool() { + v.Value = types.TaosTinyint(1) + } else { + v.Value = types.TaosTinyint(0) + } + case reflect.Float32, reflect.Float64: + v.Value = types.TaosTinyint(rv.Float()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.Value = types.TaosTinyint(rv.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v.Value = types.TaosTinyint(rv.Uint()) + case reflect.String: + vv, err := strconv.ParseInt(rv.String(), 0, 8) + if err != nil { + return err + } + v.Value = types.TaosTinyint(vv) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to tinyint", v) + } + case common.TSDB_DATA_TYPE_SMALLINT: + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Bool: + if rv.Bool() { + v.Value = types.TaosSmallint(1) + } else { + v.Value = types.TaosSmallint(0) + } + case reflect.Float32, reflect.Float64: + v.Value = types.TaosSmallint(rv.Float()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.Value = types.TaosSmallint(rv.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v.Value = types.TaosSmallint(rv.Uint()) + case reflect.String: + vv, err := strconv.ParseInt(rv.String(), 0, 16) + if err != nil { + return err + } + v.Value = types.TaosSmallint(vv) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to smallint", v) + } + case common.TSDB_DATA_TYPE_INT: + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Bool: + if rv.Bool() { + v.Value = types.TaosInt(1) + } else { + v.Value = types.TaosInt(0) + } + case reflect.Float32, reflect.Float64: + v.Value = types.TaosInt(rv.Float()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.Value = types.TaosInt(rv.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v.Value = types.TaosInt(rv.Uint()) + case reflect.String: + vv, err := strconv.ParseInt(rv.String(), 0, 32) + if err != nil { + return err + } + v.Value = types.TaosInt(vv) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to int", v) + } + case common.TSDB_DATA_TYPE_BIGINT: + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Bool: + if rv.Bool() { + v.Value = types.TaosBigint(1) + } else { + v.Value = types.TaosBigint(0) + } + case reflect.Float32, reflect.Float64: + v.Value = types.TaosBigint(rv.Float()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.Value = types.TaosBigint(rv.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v.Value = types.TaosBigint(rv.Uint()) + case reflect.String: + vv, err := strconv.ParseInt(rv.String(), 0, 64) + if err != nil { + return err + } + v.Value = types.TaosBigint(vv) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to bigint", v) + } + case common.TSDB_DATA_TYPE_FLOAT: + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Bool: + if rv.Bool() { + v.Value = types.TaosFloat(1) + } else { + v.Value = types.TaosFloat(0) + } + case reflect.Float32, reflect.Float64: + v.Value = types.TaosFloat(rv.Float()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.Value = types.TaosFloat(rv.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v.Value = types.TaosFloat(rv.Uint()) + case reflect.String: + vv, err := strconv.ParseFloat(rv.String(), 32) + if err != nil { + return err + } + v.Value = types.TaosFloat(vv) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to float", v) + } + case common.TSDB_DATA_TYPE_DOUBLE: + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Bool: + if rv.Bool() { + v.Value = types.TaosDouble(1) + } else { + v.Value = types.TaosDouble(0) + } + case reflect.Float32, reflect.Float64: + v.Value = types.TaosDouble(rv.Float()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.Value = types.TaosDouble(rv.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v.Value = types.TaosDouble(rv.Uint()) + case reflect.String: + vv, err := strconv.ParseFloat(rv.String(), 64) + if err != nil { + return err + } + v.Value = types.TaosDouble(vv) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to double", v) + } + case common.TSDB_DATA_TYPE_BINARY: + switch v.Value.(type) { + case string: + v.Value = types.TaosBinary(v.Value.(string)) + case []byte: + v.Value = types.TaosBinary(v.Value.([]byte)) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to binary", v) + } + case common.TSDB_DATA_TYPE_VARBINARY: + switch v.Value.(type) { + case string: + v.Value = types.TaosVarBinary(v.Value.(string)) + case []byte: + v.Value = types.TaosVarBinary(v.Value.([]byte)) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to varbinary", v) + } + + case common.TSDB_DATA_TYPE_GEOMETRY: + switch v.Value.(type) { + case string: + v.Value = types.TaosGeometry(v.Value.(string)) + case []byte: + v.Value = types.TaosGeometry(v.Value.([]byte)) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to geometry", v) + } + + case common.TSDB_DATA_TYPE_TIMESTAMP: + t, is := v.Value.(time.Time) + if is { + v.Value = types.TaosTimestamp{ + T: t, + Precision: int(stmt.cols[v.Ordinal-1].Precision), + } + return nil + } + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Float32, reflect.Float64: + t := common.TimestampConvertToTime(int64(rv.Float()), int(stmt.cols[v.Ordinal-1].Precision)) + v.Value = types.TaosTimestamp{ + T: t, + Precision: int(stmt.cols[v.Ordinal-1].Precision), + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + t := common.TimestampConvertToTime(rv.Int(), int(stmt.cols[v.Ordinal-1].Precision)) + v.Value = types.TaosTimestamp{ + T: t, + Precision: int(stmt.cols[v.Ordinal-1].Precision), + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + t := common.TimestampConvertToTime(int64(rv.Uint()), int(stmt.cols[v.Ordinal-1].Precision)) + v.Value = types.TaosTimestamp{ + T: t, + Precision: int(stmt.cols[v.Ordinal-1].Precision), + } + case reflect.String: + t, err := time.Parse(time.RFC3339Nano, rv.String()) + if err != nil { + return err + } + v.Value = types.TaosTimestamp{ + T: t, + Precision: int(stmt.cols[v.Ordinal-1].Precision), + } + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to timestamp", v) + } + case common.TSDB_DATA_TYPE_NCHAR: + switch v.Value.(type) { + case string: + v.Value = types.TaosNchar(v.Value.(string)) + case []byte: + v.Value = types.TaosNchar(v.Value.([]byte)) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to nchar", v) + } + case common.TSDB_DATA_TYPE_UTINYINT: + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Bool: + if rv.Bool() { + v.Value = types.TaosUTinyint(1) + } else { + v.Value = types.TaosUTinyint(0) + } + case reflect.Float32, reflect.Float64: + v.Value = types.TaosUTinyint(rv.Float()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.Value = types.TaosUTinyint(rv.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v.Value = types.TaosUTinyint(rv.Uint()) + case reflect.String: + vv, err := strconv.ParseUint(rv.String(), 0, 8) + if err != nil { + return err + } + v.Value = types.TaosUTinyint(vv) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to tinyint unsigned", v) + } + case common.TSDB_DATA_TYPE_USMALLINT: + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Bool: + if rv.Bool() { + v.Value = types.TaosUSmallint(1) + } else { + v.Value = types.TaosUSmallint(0) + } + case reflect.Float32, reflect.Float64: + v.Value = types.TaosUSmallint(rv.Float()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.Value = types.TaosUSmallint(rv.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v.Value = types.TaosUSmallint(rv.Uint()) + case reflect.String: + vv, err := strconv.ParseUint(rv.String(), 0, 16) + if err != nil { + return err + } + v.Value = types.TaosUSmallint(vv) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to smallint unsigned", v) + } + case common.TSDB_DATA_TYPE_UINT: + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Bool: + if rv.Bool() { + v.Value = types.TaosUInt(1) + } else { + v.Value = types.TaosUInt(0) + } + case reflect.Float32, reflect.Float64: + v.Value = types.TaosUInt(rv.Float()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.Value = types.TaosUInt(rv.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v.Value = types.TaosUInt(rv.Uint()) + case reflect.String: + vv, err := strconv.ParseUint(rv.String(), 0, 32) + if err != nil { + return err + } + v.Value = types.TaosUInt(vv) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to int unsigned", v) + } + case common.TSDB_DATA_TYPE_UBIGINT: + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Bool: + if rv.Bool() { + v.Value = types.TaosUBigint(1) + } else { + v.Value = types.TaosUBigint(0) + } + case reflect.Float32, reflect.Float64: + v.Value = types.TaosUBigint(rv.Float()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.Value = types.TaosUBigint(rv.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v.Value = types.TaosUBigint(rv.Uint()) + case reflect.String: + vv, err := strconv.ParseUint(rv.String(), 0, 64) + if err != nil { + return err + } + v.Value = types.TaosUBigint(vv) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to bigint unsigned", v) + } + } + return nil + } else { + if v.Value == nil { + return errors.New("CheckNamedValue: value is nil") + } + if v.Ordinal == 1 { + stmt.queryColTypes = nil + } + if len(stmt.queryColTypes) < v.Ordinal { + tmp := stmt.queryColTypes + stmt.queryColTypes = make([]*types.ColumnType, v.Ordinal) + copy(stmt.queryColTypes, tmp) + } + t, is := v.Value.(time.Time) + if is { + v.Value = types.TaosBinary(t.Format(time.RFC3339Nano)) + stmt.queryColTypes[v.Ordinal-1] = &types.ColumnType{Type: types.TaosBinaryType} + return nil + } + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Bool: + v.Value = types.TaosBool(rv.Bool()) + stmt.queryColTypes[v.Ordinal-1] = &types.ColumnType{Type: types.TaosBoolType} + case reflect.Float32, reflect.Float64: + v.Value = types.TaosDouble(rv.Float()) + stmt.queryColTypes[v.Ordinal-1] = &types.ColumnType{Type: types.TaosDoubleType} + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.Value = types.TaosBigint(rv.Int()) + stmt.queryColTypes[v.Ordinal-1] = &types.ColumnType{Type: types.TaosBigintType} + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v.Value = types.TaosUBigint(rv.Uint()) + stmt.queryColTypes[v.Ordinal-1] = &types.ColumnType{Type: types.TaosUBigintType} + case reflect.String: + strVal := rv.String() + v.Value = types.TaosBinary(strVal) + stmt.queryColTypes[v.Ordinal-1] = &types.ColumnType{ + Type: types.TaosBinaryType, + MaxLen: len(strVal), + } + case reflect.Slice: + ek := rv.Type().Elem().Kind() + if ek == reflect.Uint8 { + bsVal := rv.Bytes() + v.Value = types.TaosBinary(bsVal) + stmt.queryColTypes[v.Ordinal-1] = &types.ColumnType{ + Type: types.TaosBinaryType, + MaxLen: len(bsVal), + } + } else { + return fmt.Errorf("CheckNamedValue: can not convert query value %v", v) + } + default: + return fmt.Errorf("CheckNamedValue: can not convert query value %v", v) + } + return nil + } +} diff --git a/taosWS/statement_test.go b/taosWS/statement_test.go new file mode 100644 index 0000000..1ab008c --- /dev/null +++ b/taosWS/statement_test.go @@ -0,0 +1,2159 @@ +package taosWS + +import ( + "database/sql" + "database/sql/driver" + "fmt" + "reflect" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestStmtExec(t *testing.T) { + db, err := sql.Open(driverName, dataSourceName) + if err != nil { + t.Error(err) + return + } + defer func() { + t.Log("start3") + db.Close() + t.Log("done3") + }() + defer func() { + _, err = db.Exec("drop database if exists test_stmt_driver_ws") + if err != nil { + t.Error(err) + return + } + t.Log("done2") + }() + _, err = db.Exec("create database if not exists test_stmt_driver_ws") + if err != nil { + t.Error(err) + return + } + _, err = db.Exec("create table if not exists test_stmt_driver_ws.ct(ts timestamp," + + "c1 bool," + + "c2 tinyint," + + "c3 smallint," + + "c4 int," + + "c5 bigint," + + "c6 tinyint unsigned," + + "c7 smallint unsigned," + + "c8 int unsigned," + + "c9 bigint unsigned," + + "c10 float," + + "c11 double," + + "c12 binary(20)," + + "c13 nchar(20)" + + ")") + if err != nil { + t.Error(err) + return + } + stmt, err := db.Prepare("insert into test_stmt_driver_ws.ct values (?,?,?,?,?,?,?,?,?,?,?,?,?,?)") + + if err != nil { + t.Error(err) + return + } + result, err := stmt.Exec(time.Now(), 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, "binary", "nchar") + if err != nil { + t.Error(err) + return + } + affected, err := result.RowsAffected() + assert.NoError(t, err) + assert.Equal(t, int64(1), affected) + t.Log("done") +} + +func TestStmtQuery(t *testing.T) { + db, err := sql.Open(driverName, dataSourceName) + if err != nil { + t.Error(err) + return + } + defer db.Close() + defer func() { + db.Exec("drop database if exists test_stmt_driver_ws_q") + }() + _, err = db.Exec("create database if not exists test_stmt_driver_ws_q") + if err != nil { + t.Error(err) + return + } + _, err = db.Exec("create table if not exists test_stmt_driver_ws_q.ct(ts timestamp," + + "c1 bool," + + "c2 tinyint," + + "c3 smallint," + + "c4 int," + + "c5 bigint," + + "c6 tinyint unsigned," + + "c7 smallint unsigned," + + "c8 int unsigned," + + "c9 bigint unsigned," + + "c10 float," + + "c11 double," + + "c12 binary(20)," + + "c13 nchar(20)" + + ")") + if err != nil { + t.Error(err) + return + } + stmt, err := db.Prepare("insert into test_stmt_driver_ws_q.ct values (?,?,?,?,?,?,?,?,?,?,?,?,?,?)") + if err != nil { + t.Error(err) + return + } + now := time.Now() + result, err := stmt.Exec(now, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, "binary", "nchar") + if err != nil { + t.Error(err) + return + } + affected, err := result.RowsAffected() + if err != nil { + t.Error(err) + return + } + assert.Equal(t, int64(1), affected) + stmt.Close() + stmt, err = db.Prepare("select * from test_stmt_driver_ws_q.ct where ts = ?") + if err != nil { + t.Error(err) + return + } + rows, err := stmt.Query(now) + if err != nil { + t.Error(err) + return + } + columns, err := rows.Columns() + if err != nil { + t.Error(err) + return + } + assert.Equal(t, []string{"ts", "c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8", "c9", "c10", "c11", "c12", "c13"}, columns) + count := 0 + for rows.Next() { + count += 1 + var ( + ts time.Time + c1 bool + c2 int8 + c3 int16 + c4 int32 + c5 int64 + c6 uint8 + c7 uint16 + c8 uint32 + c9 uint64 + c10 float32 + c11 float64 + c12 string + c13 string + ) + err = rows.Scan(&ts, + &c1, + &c2, + &c3, + &c4, + &c5, + &c6, + &c7, + &c8, + &c9, + &c10, + &c11, + &c12, + &c13) + assert.NoError(t, err) + assert.Equal(t, now.UnixNano()/1e6, ts.UnixNano()/1e6) + assert.Equal(t, true, c1) + assert.Equal(t, int8(2), c2) + assert.Equal(t, int16(3), c3) + assert.Equal(t, int32(4), c4) + assert.Equal(t, int64(5), c5) + assert.Equal(t, uint8(6), c6) + assert.Equal(t, uint16(7), c7) + assert.Equal(t, uint32(8), c8) + assert.Equal(t, uint64(9), c9) + assert.Equal(t, float32(10), c10) + assert.Equal(t, float64(11), c11) + assert.Equal(t, "binary", c12) + assert.Equal(t, "nchar", c13) + } + assert.Equal(t, 1, count) +} + +func TestStmtConvertExec(t *testing.T) { + db, err := sql.Open(driverName, dataSourceName) + if err != nil { + t.Error(err) + return + } + defer db.Close() + _, err = db.Exec("drop database if exists test_stmt_driver_ws_convert") + if err != nil { + t.Error(err) + return + } + defer func() { + _, err = db.Exec("drop database if exists test_stmt_driver_ws_convert") + if err != nil { + t.Error(err) + return + } + }() + _, err = db.Exec("create database test_stmt_driver_ws_convert") + if err != nil { + t.Error(err) + return + } + _, err = db.Exec("use test_stmt_driver_ws_convert") + if err != nil { + t.Error(err) + return + } + now := time.Now().Format(time.RFC3339Nano) + tests := []struct { + name string + tbType string + pos string + bind []interface{} + expectValue interface{} + expectError bool + }{ + //bool + { + name: "bool_null", + tbType: "ts timestamp,v bool", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "bool_err", + tbType: "ts timestamp,v bool", + pos: "?,?", + bind: []interface{}{now, []int{123}}, + expectValue: nil, + expectError: true, + }, + { + name: "bool_bool_true", + tbType: "ts timestamp,v bool", + pos: "?,?", + bind: []interface{}{now, true}, + expectValue: true, + }, + { + name: "bool_bool_false", + tbType: "ts timestamp,v bool", + pos: "?,?", + bind: []interface{}{now, false}, + expectValue: false, + }, + { + name: "bool_float_true", + tbType: "ts timestamp,v bool", + pos: "?,?", + bind: []interface{}{now, float32(1)}, + expectValue: true, + }, + { + name: "bool_float_false", + tbType: "ts timestamp,v bool", + pos: "?,?", + bind: []interface{}{now, float32(0)}, + expectValue: false, + }, + { + name: "bool_int_true", + tbType: "ts timestamp,v bool", + pos: "?,?", + bind: []interface{}{now, int32(1)}, + expectValue: true, + }, + { + name: "bool_int_false", + tbType: "ts timestamp,v bool", + pos: "?,?", + bind: []interface{}{now, int32(0)}, + expectValue: false, + }, + { + name: "bool_uint_true", + tbType: "ts timestamp,v bool", + pos: "?,?", + bind: []interface{}{now, uint32(1)}, + expectValue: true, + }, + { + name: "bool_uint_false", + tbType: "ts timestamp,v bool", + pos: "?,?", + bind: []interface{}{now, uint32(0)}, + expectValue: false, + }, + { + name: "bool_string_true", + tbType: "ts timestamp,v bool", + pos: "?,?", + bind: []interface{}{now, "true"}, + expectValue: true, + }, + { + name: "bool_string_false", + tbType: "ts timestamp,v bool", + pos: "?,?", + bind: []interface{}{now, "false"}, + expectValue: false, + }, + //tiny int + { + name: "tiny_nil", + tbType: "ts timestamp,v tinyint", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "tiny_err", + tbType: "ts timestamp,v tinyint", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "tiny_bool_1", + tbType: "ts timestamp,v tinyint", + pos: "?,?", + bind: []interface{}{now, true}, + expectValue: int8(1), + }, + { + name: "tiny_bool_0", + tbType: "ts timestamp,v tinyint", + pos: "?,?", + bind: []interface{}{now, false}, + expectValue: int8(0), + }, + { + name: "tiny_float_1", + tbType: "ts timestamp,v tinyint", + pos: "?,?", + bind: []interface{}{now, float32(1)}, + expectValue: int8(1), + }, + { + name: "tiny_int_1", + tbType: "ts timestamp,v tinyint", + pos: "?,?", + bind: []interface{}{now, int(1)}, + expectValue: int8(1), + }, + { + name: "tiny_uint_1", + tbType: "ts timestamp,v tinyint", + pos: "?,?", + bind: []interface{}{now, uint(1)}, + expectValue: int8(1), + }, + { + name: "tiny_string_1", + tbType: "ts timestamp,v tinyint", + pos: "?,?", + bind: []interface{}{now, "1"}, + expectValue: int8(1), + }, + // small int + { + name: "small_nil", + tbType: "ts timestamp,v smallint", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "small_err", + tbType: "ts timestamp,v smallint", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "small_bool_1", + tbType: "ts timestamp,v smallint", + pos: "?,?", + bind: []interface{}{now, true}, + expectValue: int16(1), + }, + { + name: "small_bool_0", + tbType: "ts timestamp,v smallint", + pos: "?,?", + bind: []interface{}{now, false}, + expectValue: int16(0), + }, + { + name: "small_float_1", + tbType: "ts timestamp,v smallint", + pos: "?,?", + bind: []interface{}{now, float32(1)}, + expectValue: int16(1), + }, + { + name: "small_int_1", + tbType: "ts timestamp,v smallint", + pos: "?,?", + bind: []interface{}{now, int(1)}, + expectValue: int16(1), + }, + { + name: "small_uint_1", + tbType: "ts timestamp,v smallint", + pos: "?,?", + bind: []interface{}{now, uint(1)}, + expectValue: int16(1), + }, + { + name: "small_string_1", + tbType: "ts timestamp,v smallint", + pos: "?,?", + bind: []interface{}{now, "1"}, + expectValue: int16(1), + }, + // int + { + name: "int_nil", + tbType: "ts timestamp,v int", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "int_err", + tbType: "ts timestamp,v int", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "int_bool_1", + tbType: "ts timestamp,v int", + pos: "?,?", + bind: []interface{}{now, true}, + expectValue: int32(1), + }, + { + name: "int_bool_0", + tbType: "ts timestamp,v int", + pos: "?,?", + bind: []interface{}{now, false}, + expectValue: int32(0), + }, + { + name: "int_float_1", + tbType: "ts timestamp,v int", + pos: "?,?", + bind: []interface{}{now, float32(1)}, + expectValue: int32(1), + }, + { + name: "int_int_1", + tbType: "ts timestamp,v int", + pos: "?,?", + bind: []interface{}{now, int(1)}, + expectValue: int32(1), + }, + { + name: "int_uint_1", + tbType: "ts timestamp,v int", + pos: "?,?", + bind: []interface{}{now, uint(1)}, + expectValue: int32(1), + }, + { + name: "int_string_1", + tbType: "ts timestamp,v int", + pos: "?,?", + bind: []interface{}{now, "1"}, + expectValue: int32(1), + }, + // big int + { + name: "big_nil", + tbType: "ts timestamp,v bigint", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "big_err", + tbType: "ts timestamp,v bigint", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "big_bool_1", + tbType: "ts timestamp,v bigint", + pos: "?,?", + bind: []interface{}{now, true}, + expectValue: int64(1), + }, + { + name: "big_bool_0", + tbType: "ts timestamp,v bigint", + pos: "?,?", + bind: []interface{}{now, false}, + expectValue: int64(0), + }, + { + name: "big_float_1", + tbType: "ts timestamp,v bigint", + pos: "?,?", + bind: []interface{}{now, float32(1)}, + expectValue: int64(1), + }, + { + name: "big_int_1", + tbType: "ts timestamp,v bigint", + pos: "?,?", + bind: []interface{}{now, int(1)}, + expectValue: int64(1), + }, + { + name: "big_uint_1", + tbType: "ts timestamp,v bigint", + pos: "?,?", + bind: []interface{}{now, uint(1)}, + expectValue: int64(1), + }, + { + name: "big_string_1", + tbType: "ts timestamp,v bigint", + pos: "?,?", + bind: []interface{}{now, "1"}, + expectValue: int64(1), + }, + // float + { + name: "float_nil", + tbType: "ts timestamp,v float", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "float_err", + tbType: "ts timestamp,v float", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "float_bool_1", + tbType: "ts timestamp,v float", + pos: "?,?", + bind: []interface{}{now, true}, + expectValue: float32(1), + }, + { + name: "float_bool_0", + tbType: "ts timestamp,v float", + pos: "?,?", + bind: []interface{}{now, false}, + expectValue: float32(0), + }, + { + name: "float_float_1", + tbType: "ts timestamp,v float", + pos: "?,?", + bind: []interface{}{now, float32(1)}, + expectValue: float32(1), + }, + { + name: "float_int_1", + tbType: "ts timestamp,v float", + pos: "?,?", + bind: []interface{}{now, int(1)}, + expectValue: float32(1), + }, + { + name: "float_uint_1", + tbType: "ts timestamp,v float", + pos: "?,?", + bind: []interface{}{now, uint(1)}, + expectValue: float32(1), + }, + { + name: "float_string_1", + tbType: "ts timestamp,v float", + pos: "?,?", + bind: []interface{}{now, "1"}, + expectValue: float32(1), + }, + //double + { + name: "double_nil", + tbType: "ts timestamp,v double", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "double_err", + tbType: "ts timestamp,v double", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "double_bool_1", + tbType: "ts timestamp,v double", + pos: "?,?", + bind: []interface{}{now, true}, + expectValue: float64(1), + }, + { + name: "double_bool_0", + tbType: "ts timestamp,v double", + pos: "?,?", + bind: []interface{}{now, false}, + expectValue: float64(0), + }, + { + name: "double_double_1", + tbType: "ts timestamp,v double", + pos: "?,?", + bind: []interface{}{now, float32(1)}, + expectValue: float64(1), + }, + { + name: "double_int_1", + tbType: "ts timestamp,v double", + pos: "?,?", + bind: []interface{}{now, int(1)}, + expectValue: float64(1), + }, + { + name: "double_uint_1", + tbType: "ts timestamp,v double", + pos: "?,?", + bind: []interface{}{now, uint(1)}, + expectValue: float64(1), + }, + { + name: "double_string_1", + tbType: "ts timestamp,v double", + pos: "?,?", + bind: []interface{}{now, "1"}, + expectValue: float64(1), + }, + + //tiny int unsigned + { + name: "utiny_nil", + tbType: "ts timestamp,v tinyint unsigned", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "utiny_err", + tbType: "ts timestamp,v tinyint unsigned", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "utiny_bool_1", + tbType: "ts timestamp,v tinyint unsigned", + pos: "?,?", + bind: []interface{}{now, true}, + expectValue: uint8(1), + }, + { + name: "utiny_bool_0", + tbType: "ts timestamp,v tinyint unsigned", + pos: "?,?", + bind: []interface{}{now, false}, + expectValue: uint8(0), + }, + { + name: "utiny_float_1", + tbType: "ts timestamp,v tinyint unsigned", + pos: "?,?", + bind: []interface{}{now, float32(1)}, + expectValue: uint8(1), + }, + { + name: "utiny_int_1", + tbType: "ts timestamp,v tinyint unsigned", + pos: "?,?", + bind: []interface{}{now, int(1)}, + expectValue: uint8(1), + }, + { + name: "utiny_uint_1", + tbType: "ts timestamp,v tinyint unsigned", + pos: "?,?", + bind: []interface{}{now, uint(1)}, + expectValue: uint8(1), + }, + { + name: "utiny_string_1", + tbType: "ts timestamp,v tinyint unsigned", + pos: "?,?", + bind: []interface{}{now, "1"}, + expectValue: uint8(1), + }, + // small int unsigned + { + name: "usmall_nil", + tbType: "ts timestamp,v smallint unsigned", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "usmall_err", + tbType: "ts timestamp,v smallint unsigned", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "usmall_bool_1", + tbType: "ts timestamp,v smallint unsigned", + pos: "?,?", + bind: []interface{}{now, true}, + expectValue: uint16(1), + }, + { + name: "usmall_bool_0", + tbType: "ts timestamp,v smallint unsigned", + pos: "?,?", + bind: []interface{}{now, false}, + expectValue: uint16(0), + }, + { + name: "usmall_float_1", + tbType: "ts timestamp,v smallint unsigned", + pos: "?,?", + bind: []interface{}{now, float32(1)}, + expectValue: uint16(1), + }, + { + name: "usmall_int_1", + tbType: "ts timestamp,v smallint unsigned", + pos: "?,?", + bind: []interface{}{now, int(1)}, + expectValue: uint16(1), + }, + { + name: "usmall_uint_1", + tbType: "ts timestamp,v smallint unsigned", + pos: "?,?", + bind: []interface{}{now, uint(1)}, + expectValue: uint16(1), + }, + { + name: "usmall_string_1", + tbType: "ts timestamp,v smallint unsigned", + pos: "?,?", + bind: []interface{}{now, "1"}, + expectValue: uint16(1), + }, + // int unsigned + { + name: "uint_nil", + tbType: "ts timestamp,v int unsigned", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "uint_err", + tbType: "ts timestamp,v int unsigned", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "uint_bool_1", + tbType: "ts timestamp,v int unsigned", + pos: "?,?", + bind: []interface{}{now, true}, + expectValue: uint32(1), + }, + { + name: "uint_bool_0", + tbType: "ts timestamp,v int unsigned", + pos: "?,?", + bind: []interface{}{now, false}, + expectValue: uint32(0), + }, + { + name: "uint_float_1", + tbType: "ts timestamp,v int unsigned", + pos: "?,?", + bind: []interface{}{now, float32(1)}, + expectValue: uint32(1), + }, + { + name: "uint_int_1", + tbType: "ts timestamp,v int unsigned", + pos: "?,?", + bind: []interface{}{now, int(1)}, + expectValue: uint32(1), + }, + { + name: "uint_uint_1", + tbType: "ts timestamp,v int unsigned", + pos: "?,?", + bind: []interface{}{now, uint(1)}, + expectValue: uint32(1), + }, + { + name: "uint_string_1", + tbType: "ts timestamp,v int unsigned", + pos: "?,?", + bind: []interface{}{now, "1"}, + expectValue: uint32(1), + }, + // big int unsigned + { + name: "ubig_nil", + tbType: "ts timestamp,v bigint unsigned", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "ubig_err", + tbType: "ts timestamp,v bigint unsigned", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "ubig_bool_1", + tbType: "ts timestamp,v bigint unsigned", + pos: "?,?", + bind: []interface{}{now, true}, + expectValue: uint64(1), + }, + { + name: "ubig_bool_0", + tbType: "ts timestamp,v bigint unsigned", + pos: "?,?", + bind: []interface{}{now, false}, + expectValue: uint64(0), + }, + { + name: "ubig_float_1", + tbType: "ts timestamp,v bigint unsigned", + pos: "?,?", + bind: []interface{}{now, float32(1)}, + expectValue: uint64(1), + }, + { + name: "ubig_int_1", + tbType: "ts timestamp,v bigint unsigned", + pos: "?,?", + bind: []interface{}{now, int(1)}, + expectValue: uint64(1), + }, + { + name: "ubig_uint_1", + tbType: "ts timestamp,v bigint unsigned", + pos: "?,?", + bind: []interface{}{now, uint(1)}, + expectValue: uint64(1), + }, + { + name: "ubig_string_1", + tbType: "ts timestamp,v bigint unsigned", + pos: "?,?", + bind: []interface{}{now, "1"}, + expectValue: uint64(1), + }, + //binary + { + name: "binary_nil", + tbType: "ts timestamp,v binary(24)", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "binary_err", + tbType: "ts timestamp,v binary(24)", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "binary_string_chinese", + tbType: "ts timestamp,v binary(24)", + pos: "?,?", + bind: []interface{}{now, "中文"}, + expectValue: "中文", + }, + { + name: "binary_bytes_chinese", + tbType: "ts timestamp,v binary(24)", + pos: "?,?", + bind: []interface{}{now, []byte("中文")}, + expectValue: "中文", + }, + //nchar + { + name: "nchar_nil", + tbType: "ts timestamp,v nchar(24)", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "nchar_err", + tbType: "ts timestamp,v nchar(24)", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "binary_string_chinese", + tbType: "ts timestamp,v nchar(24)", + pos: "?,?", + bind: []interface{}{now, "中文"}, + expectValue: "中文", + }, + { + name: "binary_bytes_chinese", + tbType: "ts timestamp,v nchar(24)", + pos: "?,?", + bind: []interface{}{now, []byte("中文")}, + expectValue: "中文", + }, + // timestamp + { + name: "ts_nil", + tbType: "ts timestamp,v timestamp", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "ts_err", + tbType: "ts timestamp,v timestamp", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "ts_time_1", + tbType: "ts timestamp,v timestamp", + pos: "?,?", + bind: []interface{}{now, time.Unix(0, 1e6)}, + expectValue: time.Unix(0, 1e6), + }, + { + name: "ts_float_1", + tbType: "ts timestamp,v timestamp", + pos: "?,?", + bind: []interface{}{now, float32(1)}, + expectValue: time.Unix(0, 1e6), + }, + { + name: "ts_int_1", + tbType: "ts timestamp,v timestamp", + pos: "?,?", + bind: []interface{}{now, int(1)}, + expectValue: time.Unix(0, 1e6), + }, + { + name: "ts_uint_1", + tbType: "ts timestamp,v timestamp", + pos: "?,?", + bind: []interface{}{now, uint(1)}, + expectValue: time.Unix(0, 1e6), + }, + { + name: "ts_string_1", + tbType: "ts timestamp,v timestamp", + pos: "?,?", + bind: []interface{}{now, "1970-01-01T00:00:00.001Z"}, + expectValue: time.Unix(0, 1e6), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tbName := fmt.Sprintf("test_%s", tt.name) + tbType := tt.tbType + drop := fmt.Sprintf("drop table if exists %s", tbName) + create := fmt.Sprintf("create table if not exists %s(%s)", tbName, tbType) + pos := tt.pos + sql := fmt.Sprintf("insert into %s values(%s)", tbName, pos) + var err error + if _, err = db.Exec(drop); err != nil { + t.Error(err) + return + } + if _, err = db.Exec(create); err != nil { + t.Error(err) + return + } + stmt, err := db.Prepare(sql) + if err != nil { + t.Error(err) + return + } + result, err := stmt.Exec(tt.bind...) + if tt.expectError { + assert.NotNil(t, err) + stmt.Close() + return + } + if err != nil { + t.Error(err) + return + } + affected, err := result.RowsAffected() + if err != nil { + t.Error(err) + return + } + assert.Equal(t, int64(1), affected) + rows, err := db.Query(fmt.Sprintf("select v from %s", tbName)) + if err != nil { + t.Error(err) + return + } + var data []driver.Value + tts, err := rows.ColumnTypes() + if err != nil { + t.Error(err) + return + } + typesL := make([]reflect.Type, 1) + for i, tp := range tts { + st := tp.ScanType() + if st == nil { + t.Errorf("scantype is null for column %q", tp.Name()) + continue + } + typesL[i] = st + } + for rows.Next() { + values := make([]interface{}, 1) + for i := range values { + values[i] = reflect.New(typesL[i]).Interface() + } + err = rows.Scan(values...) + if err != nil { + t.Error(err) + return + } + v, err := values[0].(driver.Valuer).Value() + if err != nil { + t.Error(err) + } + data = append(data, v) + } + if len(data) != 1 { + t.Errorf("expect %d got %d", 1, len(data)) + return + } + if data[0] != tt.expectValue { + t.Errorf("expect %v got %v", tt.expectValue, data[0]) + return + } + }) + } +} + +func TestStmtConvertQuery(t *testing.T) { + db, err := sql.Open(driverName, dataSourceName) + if err != nil { + t.Error(err) + return + } + defer db.Close() + _, err = db.Exec("drop database if exists test_stmt_driver_ws_convert_q") + if err != nil { + t.Error(err) + return + } + defer func() { + _, err = db.Exec("drop database if exists test_stmt_driver_ws_convert_q") + if err != nil { + t.Error(err) + return + } + }() + _, err = db.Exec("create database test_stmt_driver_ws_convert_q") + if err != nil { + t.Error(err) + return + } + _, err = db.Exec("use test_stmt_driver_ws_convert_q") + if err != nil { + t.Error(err) + return + } + _, err = db.Exec("create table t0 (ts timestamp," + + "c1 bool," + + "c2 tinyint," + + "c3 smallint," + + "c4 int," + + "c5 bigint," + + "c6 tinyint unsigned," + + "c7 smallint unsigned," + + "c8 int unsigned," + + "c9 bigint unsigned," + + "c10 float," + + "c11 double," + + "c12 binary(20)," + + "c13 nchar(20)" + + ")") + if err != nil { + t.Error(err) + return + } + now := time.Now() + after1s := now.Add(time.Second) + _, err = db.Exec(fmt.Sprintf("insert into t0 values('%s',true,2,3,4,5,6,7,8,9,10,11,'binary','nchar')", now.Format(time.RFC3339Nano))) + if err != nil { + t.Error(err) + return + } + _, err = db.Exec(fmt.Sprintf("insert into t0 values('%s',null,null,null,null,null,null,null,null,null,null,null,null,null)", after1s.Format(time.RFC3339Nano))) + if err != nil { + t.Error(err) + return + } + tests := []struct { + name string + field string + where string + bind interface{} + expectNoValue bool + expectValue driver.Value + expectError bool + }{ + //ts + { + name: "ts", + field: "ts", + where: "ts = ?", + bind: now, + expectValue: time.Unix(now.Unix(), int64((now.Nanosecond()/1e6)*1e6)).Local(), + }, + + //bool + { + name: "bool_true", + field: "c1", + where: "c1 = ?", + bind: true, + expectValue: true, + }, + { + name: "bool_false", + field: "c1", + where: "c1 = ?", + bind: false, + expectNoValue: true, + }, + { + name: "tinyint_int8", + field: "c2", + where: "c2 = ?", + bind: int8(2), + expectValue: int8(2), + }, + { + name: "tinyint_iny16", + field: "c2", + where: "c2 = ?", + bind: int16(2), + expectValue: int8(2), + }, + { + name: "tinyint_int32", + field: "c2", + where: "c2 = ?", + bind: int32(2), + expectValue: int8(2), + }, + { + name: "tinyint_int64", + field: "c2", + where: "c2 = ?", + bind: int64(2), + expectValue: int8(2), + }, + { + name: "tinyint_uint8", + field: "c2", + where: "c2 = ?", + bind: uint8(2), + expectValue: int8(2), + }, + { + name: "tinyint_uint16", + field: "c2", + where: "c2 = ?", + bind: uint16(2), + expectValue: int8(2), + }, + { + name: "tinyint_uint32", + field: "c2", + where: "c2 = ?", + bind: uint32(2), + expectValue: int8(2), + }, + { + name: "tinyint_uint64", + field: "c2", + where: "c2 = ?", + bind: uint64(2), + expectValue: int8(2), + }, + { + name: "tinyint_float32", + field: "c2", + where: "c2 = ?", + bind: float32(2), + expectValue: int8(2), + }, + { + name: "tinyint_float64", + field: "c2", + where: "c2 = ?", + bind: float64(2), + expectValue: int8(2), + }, + { + name: "tinyint_int", + field: "c2", + where: "c2 = ?", + bind: int(2), + expectValue: int8(2), + }, + { + name: "tinyint_uint", + field: "c2", + where: "c2 = ?", + bind: uint(2), + expectValue: int8(2), + }, + + // smallint + { + name: "smallint_int8", + field: "c3", + where: "c3 = ?", + bind: int8(3), + expectValue: int16(3), + }, + { + name: "smallint_iny16", + field: "c3", + where: "c3 = ?", + bind: int16(3), + expectValue: int16(3), + }, + { + name: "smallint_int32", + field: "c3", + where: "c3 = ?", + bind: int32(3), + expectValue: int16(3), + }, + { + name: "smallint_int64", + field: "c3", + where: "c3 = ?", + bind: int64(3), + expectValue: int16(3), + }, + { + name: "smallint_uint8", + field: "c3", + where: "c3 = ?", + bind: uint8(3), + expectValue: int16(3), + }, + { + name: "smallint_uint16", + field: "c3", + where: "c3 = ?", + bind: uint16(3), + expectValue: int16(3), + }, + { + name: "smallint_uint32", + field: "c3", + where: "c3 = ?", + bind: uint32(3), + expectValue: int16(3), + }, + { + name: "smallint_uint64", + field: "c3", + where: "c3 = ?", + bind: uint64(3), + expectValue: int16(3), + }, + { + name: "smallint_float32", + field: "c3", + where: "c3 = ?", + bind: float32(3), + expectValue: int16(3), + }, + { + name: "smallint_float64", + field: "c3", + where: "c3 = ?", + bind: float64(3), + expectValue: int16(3), + }, + { + name: "smallint_int", + field: "c3", + where: "c3 = ?", + bind: int(3), + expectValue: int16(3), + }, + { + name: "smallint_uint", + field: "c3", + where: "c3 = ?", + bind: uint(3), + expectValue: int16(3), + }, + + //int + { + name: "int_int8", + field: "c4", + where: "c4 = ?", + bind: int8(4), + expectValue: int32(4), + }, + { + name: "int_iny16", + field: "c4", + where: "c4 = ?", + bind: int16(4), + expectValue: int32(4), + }, + { + name: "int_int32", + field: "c4", + where: "c4 = ?", + bind: int32(4), + expectValue: int32(4), + }, + { + name: "int_int64", + field: "c4", + where: "c4 = ?", + bind: int64(4), + expectValue: int32(4), + }, + { + name: "int_uint8", + field: "c4", + where: "c4 = ?", + bind: uint8(4), + expectValue: int32(4), + }, + { + name: "int_uint16", + field: "c4", + where: "c4 = ?", + bind: uint16(4), + expectValue: int32(4), + }, + { + name: "int_uint32", + field: "c4", + where: "c4 = ?", + bind: uint32(4), + expectValue: int32(4), + }, + { + name: "int_uint64", + field: "c4", + where: "c4 = ?", + bind: uint64(4), + expectValue: int32(4), + }, + { + name: "int_float32", + field: "c4", + where: "c4 = ?", + bind: float32(4), + expectValue: int32(4), + }, + { + name: "int_float64", + field: "c4", + where: "c4 = ?", + bind: float64(4), + expectValue: int32(4), + }, + { + name: "int_int", + field: "c4", + where: "c4 = ?", + bind: int(4), + expectValue: int32(4), + }, + { + name: "int_uint", + field: "c4", + where: "c4 = ?", + bind: uint(4), + expectValue: int32(4), + }, + + //bigint + { + name: "bigint_int8", + field: "c5", + where: "c5 = ?", + bind: int8(5), + expectValue: int64(5), + }, + { + name: "bigint_iny16", + field: "c5", + where: "c5 = ?", + bind: int16(5), + expectValue: int64(5), + }, + { + name: "bigint_int32", + field: "c5", + where: "c5 = ?", + bind: int32(5), + expectValue: int64(5), + }, + { + name: "bigint_int64", + field: "c5", + where: "c5 = ?", + bind: int64(5), + expectValue: int64(5), + }, + { + name: "bigint_uint8", + field: "c5", + where: "c5 = ?", + bind: uint8(5), + expectValue: int64(5), + }, + { + name: "bigint_uint16", + field: "c5", + where: "c5 = ?", + bind: uint16(5), + expectValue: int64(5), + }, + { + name: "bigint_uint32", + field: "c5", + where: "c5 = ?", + bind: uint32(5), + expectValue: int64(5), + }, + { + name: "bigint_uint64", + field: "c5", + where: "c5 = ?", + bind: uint64(5), + expectValue: int64(5), + }, + { + name: "bigint_float32", + field: "c5", + where: "c5 = ?", + bind: float32(5), + expectValue: int64(5), + }, + { + name: "bigint_float64", + field: "c5", + where: "c5 = ?", + bind: float64(5), + expectValue: int64(5), + }, + { + name: "bigint_int", + field: "c5", + where: "c5 = ?", + bind: int(5), + expectValue: int64(5), + }, + { + name: "bigint_uint", + field: "c5", + where: "c5 = ?", + bind: uint(5), + expectValue: int64(5), + }, + + //utinyint + { + name: "utinyint_int8", + field: "c6", + where: "c6 = ?", + bind: int8(6), + expectValue: uint8(6), + }, + { + name: "utinyint_iny16", + field: "c6", + where: "c6 = ?", + bind: int16(6), + expectValue: uint8(6), + }, + { + name: "utinyint_int32", + field: "c6", + where: "c6 = ?", + bind: int32(6), + expectValue: uint8(6), + }, + { + name: "utinyint_int64", + field: "c6", + where: "c6 = ?", + bind: int64(6), + expectValue: uint8(6), + }, + { + name: "utinyint_uint8", + field: "c6", + where: "c6 = ?", + bind: uint8(6), + expectValue: uint8(6), + }, + { + name: "utinyint_uint16", + field: "c6", + where: "c6 = ?", + bind: uint16(6), + expectValue: uint8(6), + }, + { + name: "utinyint_uint32", + field: "c6", + where: "c6 = ?", + bind: uint32(6), + expectValue: uint8(6), + }, + { + name: "utinyint_uint64", + field: "c6", + where: "c6 = ?", + bind: uint64(6), + expectValue: uint8(6), + }, + { + name: "utinyint_float32", + field: "c6", + where: "c6 = ?", + bind: float32(6), + expectValue: uint8(6), + }, + { + name: "utinyint_float64", + field: "c6", + where: "c6 = ?", + bind: float64(6), + expectValue: uint8(6), + }, + { + name: "utinyint_int", + field: "c6", + where: "c6 = ?", + bind: int(6), + expectValue: uint8(6), + }, + { + name: "utinyint_uint", + field: "c6", + where: "c6 = ?", + bind: uint(6), + expectValue: uint8(6), + }, + + //usmallint + { + name: "usmallint_int8", + field: "c7", + where: "c7 = ?", + bind: int8(7), + expectValue: uint16(7), + }, + { + name: "usmallint_iny16", + field: "c7", + where: "c7 = ?", + bind: int16(7), + expectValue: uint16(7), + }, + { + name: "usmallint_int32", + field: "c7", + where: "c7 = ?", + bind: int32(7), + expectValue: uint16(7), + }, + { + name: "usmallint_int64", + field: "c7", + where: "c7 = ?", + bind: int64(7), + expectValue: uint16(7), + }, + { + name: "usmallint_uint8", + field: "c7", + where: "c7 = ?", + bind: uint8(7), + expectValue: uint16(7), + }, + { + name: "usmallint_uint16", + field: "c7", + where: "c7 = ?", + bind: uint16(7), + expectValue: uint16(7), + }, + { + name: "usmallint_uint32", + field: "c7", + where: "c7 = ?", + bind: uint32(7), + expectValue: uint16(7), + }, + { + name: "usmallint_uint64", + field: "c7", + where: "c7 = ?", + bind: uint64(7), + expectValue: uint16(7), + }, + { + name: "usmallint_float32", + field: "c7", + where: "c7 = ?", + bind: float32(7), + expectValue: uint16(7), + }, + { + name: "usmallint_float64", + field: "c7", + where: "c7 = ?", + bind: float64(7), + expectValue: uint16(7), + }, + { + name: "usmallint_int", + field: "c7", + where: "c7 = ?", + bind: int(7), + expectValue: uint16(7), + }, + { + name: "usmallint_uint", + field: "c7", + where: "c7 = ?", + bind: uint(7), + expectValue: uint16(7), + }, + + //uint + { + name: "uint_int8", + field: "c8", + where: "c8 = ?", + bind: int8(8), + expectValue: uint32(8), + }, + { + name: "uint_iny16", + field: "c8", + where: "c8 = ?", + bind: int16(8), + expectValue: uint32(8), + }, + { + name: "uint_int32", + field: "c8", + where: "c8 = ?", + bind: int32(8), + expectValue: uint32(8), + }, + { + name: "uint_int64", + field: "c8", + where: "c8 = ?", + bind: int64(8), + expectValue: uint32(8), + }, + { + name: "uint_uint8", + field: "c8", + where: "c8 = ?", + bind: uint8(8), + expectValue: uint32(8), + }, + { + name: "uint_uint16", + field: "c8", + where: "c8 = ?", + bind: uint16(8), + expectValue: uint32(8), + }, + { + name: "uint_uint32", + field: "c8", + where: "c8 = ?", + bind: uint32(8), + expectValue: uint32(8), + }, + { + name: "uint_uint64", + field: "c8", + where: "c8 = ?", + bind: uint64(8), + expectValue: uint32(8), + }, + { + name: "uint_float32", + field: "c8", + where: "c8 = ?", + bind: float32(8), + expectValue: uint32(8), + }, + { + name: "uint_float64", + field: "c8", + where: "c8 = ?", + bind: float64(8), + expectValue: uint32(8), + }, + { + name: "uint_int", + field: "c8", + where: "c8 = ?", + bind: int(8), + expectValue: uint32(8), + }, + { + name: "uint_uint", + field: "c8", + where: "c8 = ?", + bind: uint(8), + expectValue: uint32(8), + }, + + //ubigint + { + name: "ubigint_int8", + field: "c9", + where: "c9 = ?", + bind: int8(9), + expectValue: uint64(9), + }, + { + name: "ubigint_iny16", + field: "c9", + where: "c9 = ?", + bind: int16(9), + expectValue: uint64(9), + }, + { + name: "ubigint_int32", + field: "c9", + where: "c9 = ?", + bind: int32(9), + expectValue: uint64(9), + }, + { + name: "ubigint_int64", + field: "c9", + where: "c9 = ?", + bind: int64(9), + expectValue: uint64(9), + }, + { + name: "ubigint_uint8", + field: "c9", + where: "c9 = ?", + bind: uint8(9), + expectValue: uint64(9), + }, + { + name: "ubigint_uint16", + field: "c9", + where: "c9 = ?", + bind: uint16(9), + expectValue: uint64(9), + }, + { + name: "ubigint_uint32", + field: "c9", + where: "c9 = ?", + bind: uint32(9), + expectValue: uint64(9), + }, + { + name: "ubigint_uint64", + field: "c9", + where: "c9 = ?", + bind: uint64(9), + expectValue: uint64(9), + }, + { + name: "ubigint_float32", + field: "c9", + where: "c9 = ?", + bind: float32(9), + expectValue: uint64(9), + }, + { + name: "ubigint_float64", + field: "c9", + where: "c9 = ?", + bind: float64(9), + expectValue: uint64(9), + }, + { + name: "ubigint_int", + field: "c9", + where: "c9 = ?", + bind: int(9), + expectValue: uint64(9), + }, + { + name: "ubigint_uint", + field: "c9", + where: "c9 = ?", + bind: uint(9), + expectValue: uint64(9), + }, + + //float + { + name: "float_int8", + field: "c10", + where: "c10 = ?", + bind: int8(10), + expectValue: float32(10), + }, + { + name: "float_iny16", + field: "c10", + where: "c10 = ?", + bind: int16(10), + expectValue: float32(10), + }, + { + name: "float_int32", + field: "c10", + where: "c10 = ?", + bind: int32(10), + expectValue: float32(10), + }, + { + name: "float_int64", + field: "c10", + where: "c10 = ?", + bind: int64(10), + expectValue: float32(10), + }, + { + name: "float_uint8", + field: "c10", + where: "c10 = ?", + bind: uint8(10), + expectValue: float32(10), + }, + { + name: "float_uint16", + field: "c10", + where: "c10 = ?", + bind: uint16(10), + expectValue: float32(10), + }, + { + name: "float_uint32", + field: "c10", + where: "c10 = ?", + bind: uint32(10), + expectValue: float32(10), + }, + { + name: "float_uint64", + field: "c10", + where: "c10 = ?", + bind: uint64(10), + expectValue: float32(10), + }, + { + name: "float_float32", + field: "c10", + where: "c10 = ?", + bind: float32(10), + expectValue: float32(10), + }, + { + name: "float_float64", + field: "c10", + where: "c10 = ?", + bind: float64(10), + expectValue: float32(10), + }, + { + name: "float_int", + field: "c10", + where: "c10 = ?", + bind: int(10), + expectValue: float32(10), + }, + { + name: "float_uint", + field: "c10", + where: "c10 = ?", + bind: uint(10), + expectValue: float32(10), + }, + + //double + { + name: "double_int8", + field: "c11", + where: "c11 = ?", + bind: int8(11), + expectValue: float64(11), + }, + { + name: "double_iny16", + field: "c11", + where: "c11 = ?", + bind: int16(11), + expectValue: float64(11), + }, + { + name: "double_int32", + field: "c11", + where: "c11 = ?", + bind: int32(11), + expectValue: float64(11), + }, + { + name: "double_int64", + field: "c11", + where: "c11 = ?", + bind: int64(11), + expectValue: float64(11), + }, + { + name: "double_uint8", + field: "c11", + where: "c11 = ?", + bind: uint8(11), + expectValue: float64(11), + }, + { + name: "double_uint16", + field: "c11", + where: "c11 = ?", + bind: uint16(11), + expectValue: float64(11), + }, + { + name: "double_uint32", + field: "c11", + where: "c11 = ?", + bind: uint32(11), + expectValue: float64(11), + }, + { + name: "double_uint64", + field: "c11", + where: "c11 = ?", + bind: uint64(11), + expectValue: float64(11), + }, + { + name: "double_float32", + field: "c11", + where: "c11 = ?", + bind: float32(11), + expectValue: float64(11), + }, + { + name: "double_float64", + field: "c11", + where: "c11 = ?", + bind: float64(11), + expectValue: float64(11), + }, + { + name: "double_int", + field: "c11", + where: "c11 = ?", + bind: int(11), + expectValue: float64(11), + }, + { + name: "double_uint", + field: "c11", + where: "c11 = ?", + bind: uint(11), + expectValue: float64(11), + }, + + // binary + { + name: "binary_string", + field: "c12", + where: "c12 = ?", + bind: "binary", + expectValue: "binary", + }, + { + name: "binary_bytes", + field: "c12", + where: "c12 = ?", + bind: []byte("binary"), + expectValue: "binary", + }, + { + name: "binary_string_like", + field: "c12", + where: "c12 like ?", + bind: "bin%", + expectValue: "binary", + }, + + // nchar + { + name: "nchar_string", + field: "c13", + where: "c13 = ?", + bind: "nchar", + expectValue: "nchar", + }, + { + name: "nchar_bytes", + field: "c13", + where: "c13 = ?", + bind: []byte("nchar"), + expectValue: "nchar", + }, + { + name: "nchar_string", + field: "c13", + where: "c13 like ?", + bind: "nch%", + expectValue: "nchar", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sql := fmt.Sprintf("select %s from t0 where %s", tt.field, tt.where) + + stmt, err := db.Prepare(sql) + if err != nil { + t.Error(err) + return + } + defer stmt.Close() + rows, err := stmt.Query(tt.bind) + if tt.expectError { + assert.NotNil(t, err) + stmt.Close() + return + } + if err != nil { + t.Error(err) + return + } + tts, err := rows.ColumnTypes() + typesL := make([]reflect.Type, 1) + for i, tp := range tts { + st := tp.ScanType() + if st == nil { + t.Errorf("scantype is null for column %q", tp.Name()) + continue + } + typesL[i] = st + } + var data []driver.Value + for rows.Next() { + values := make([]interface{}, 1) + for i := range values { + values[i] = reflect.New(typesL[i]).Interface() + } + err = rows.Scan(values...) + if err != nil { + t.Error(err) + return + } + v, err := values[0].(driver.Valuer).Value() + if err != nil { + t.Error(err) + } + data = append(data, v) + } + if tt.expectNoValue { + if len(data) > 0 { + t.Errorf("expect no value got %#v", data) + return + } + return + } + if len(data) != 1 { + t.Errorf("expect %d got %d", 1, len(data)) + return + } + if data[0] != tt.expectValue { + t.Errorf("expect %v got %v", tt.expectValue, data[0]) + return + } + }) + } +} diff --git a/wrapper/notify_test.go b/wrapper/notify_test.go index fa8cd06..468dea1 100644 --- a/wrapper/notify_test.go +++ b/wrapper/notify_test.go @@ -2,7 +2,6 @@ package wrapper import ( "context" - "fmt" "testing" "time" @@ -26,6 +25,7 @@ func TestNotify(t *testing.T) { exec(conn, "drop user t_notify") err = exec(conn, "create user t_notify pass 'notify'") assert.NoError(t, err) + conn2, err := TaosConnect("", "t_notify", "notify", "", 0) if err != nil { t.Error(err) @@ -40,6 +40,22 @@ func TestNotify(t *testing.T) { errStr := TaosErrorStr(nil) t.Error(errCode, errStr) } + notifyWhitelist := make(chan int64, 1) + handlerWhiteList := cgo.NewHandle(notifyWhitelist) + errCode = TaosSetNotifyCB(conn2, handlerWhiteList, common.TAOS_NOTIFY_WHITELIST_VER) + if errCode != 0 { + errStr := TaosErrorStr(nil) + t.Error(errCode, errStr) + } + + notifyDropUser := make(chan struct{}, 1) + handlerDropUser := cgo.NewHandle(notifyDropUser) + errCode = TaosSetNotifyCB(conn2, handlerDropUser, common.TAOS_NOTIFY_USER_DROPPED) + if errCode != 0 { + errStr := TaosErrorStr(nil) + t.Error(errCode, errStr) + } + err = exec(conn, "alter user t_notify pass 'test'") assert.NoError(t, err) timeout, cancel := context.WithTimeout(context.Background(), time.Second*5) @@ -47,51 +63,36 @@ func TestNotify(t *testing.T) { now := time.Now() select { case version := <-notify: - fmt.Println(time.Now().Sub(now)) - t.Log(version) + t.Log(time.Now().Sub(now)) + t.Log("password changed", version) case <-timeout.Done(): t.Error("wait for notify callback timeout") } - { - notify := make(chan int64, 1) - handler := cgo.NewHandle(notify) - errCode := TaosSetNotifyCB(conn2, handler, common.TAOS_NOTIFY_WHITELIST_VER) - if errCode != 0 { - errStr := TaosErrorStr(nil) - t.Error(errCode, errStr) - } - err = exec(conn, "ALTER USER t_notify ADD HOST '192.168.1.98/0','192.168.1.98/32'") - assert.NoError(t, err) - timeout, cancel = context.WithTimeout(context.Background(), time.Second*5) - defer cancel() - now := time.Now() - select { - case version := <-notify: - fmt.Println(time.Now().Sub(now)) - t.Log(version) - case <-timeout.Done(): - t.Error("wait for notify callback timeout") - } + + err = exec(conn, "ALTER USER t_notify ADD HOST '192.168.1.98/0','192.168.1.98/32'") + assert.NoError(t, err) + timeoutWhiteList, cancelWhitelist := context.WithTimeout(context.Background(), time.Second*5) + defer cancelWhitelist() + now = time.Now() + select { + case version := <-notifyWhitelist: + t.Log(time.Now().Sub(now)) + t.Log("whitelist changed", version) + case <-timeoutWhiteList.Done(): + t.Error("wait for notifyWhitelist callback timeout") } - { - notify := make(chan struct{}, 1) - handler := cgo.NewHandle(notify) - errCode := TaosSetNotifyCB(conn2, handler, common.TAOS_NOTIFY_USER_DROPPED) - if errCode != 0 { - errStr := TaosErrorStr(nil) - t.Error(errCode, errStr) - } - err = exec(conn, "drop USER t_notify") - assert.NoError(t, err) - timeout, cancel = context.WithTimeout(context.Background(), time.Second*5) - defer cancel() - now := time.Now() - select { - case _ = <-notify: - fmt.Println(time.Now().Sub(now)) - t.Log("user dropped") - case <-timeout.Done(): - t.Error("wait for notify callback timeout") - } + + err = exec(conn, "drop USER t_notify") + assert.NoError(t, err) + timeoutDropUser, cancelDropUser := context.WithTimeout(context.Background(), time.Second*5) + defer cancelDropUser() + now = time.Now() + select { + case _ = <-notifyDropUser: + t.Log(time.Now().Sub(now)) + t.Log("user dropped") + case <-timeoutDropUser.Done(): + t.Error("wait for notifyDropUser callback timeoutDropUser") } + } diff --git a/wrapper/stmt_test.go b/wrapper/stmt_test.go index d924645..339a6ec 100644 --- a/wrapper/stmt_test.go +++ b/wrapper/stmt_test.go @@ -702,7 +702,6 @@ func StmtQuery(t *testing.T, conn unsafe.Pointer, sql string, params *param.Para if errCode != int(taosError.SUCCESS) { errStr := TaosErrorStr(res) err := taosError.NewError(code, errStr) - TaosFreeResult(res) return nil, err } if blockSize == 0 { @@ -711,7 +710,6 @@ func StmtQuery(t *testing.T, conn unsafe.Pointer, sql string, params *param.Para d := parser.ReadBlock(block, blockSize, rowsHeader.ColTypes, precision) data = append(data, d...) } - TaosFreeResult(res) return data, nil } diff --git a/wrapper/taosc.go b/wrapper/taosc.go index 2e952a2..5fea8eb 100644 --- a/wrapper/taosc.go +++ b/wrapper/taosc.go @@ -82,8 +82,8 @@ func TaosQuery(taosConnect unsafe.Pointer, sql string) unsafe.Pointer { return unsafe.Pointer(C.taos_query(taosConnect, cSql)) } -// TasoQueryWithReqID TAOS_RES *taos_query_with_reqid(TAOS *taos, const char *sql, int64_t reqID); -func TasoQueryWithReqID(taosConn unsafe.Pointer, sql string, reqID int64) unsafe.Pointer { +// TaosQueryWithReqID TAOS_RES *taos_query_with_reqid(TAOS *taos, const char *sql, int64_t reqID); +func TaosQueryWithReqID(taosConn unsafe.Pointer, sql string, reqID int64) unsafe.Pointer { cSql := C.CString(sql) defer C.free(unsafe.Pointer(cSql)) return unsafe.Pointer(C.taos_query_with_reqid(taosConn, cSql, (C.int64_t)(reqID))) @@ -269,11 +269,11 @@ func TaosSetConnMode(conn unsafe.Pointer, mode int, value int) int { // TaosGetCurrentDB DLL_EXPORT int taos_get_current_db(TAOS *taos, char *database, int len, int *required) func TaosGetCurrentDB(conn unsafe.Pointer) (db string, err error) { - cDb := C.CString(db) + cDb := (*C.char)(C.malloc(195)) defer C.free(unsafe.Pointer(cDb)) var required int - code := C.taos_get_current_db(conn, cDb, C.int(193), (*C.int)(unsafe.Pointer(&required))) + code := C.taos_get_current_db(conn, cDb, C.int(195), (*C.int)(unsafe.Pointer(&required))) if code != 0 { err = errors.NewError(int(code), TaosErrorStr(nil)) } diff --git a/wrapper/tmq_test.go b/wrapper/tmq_test.go index b4992cf..4e201fa 100644 --- a/wrapper/tmq_test.go +++ b/wrapper/tmq_test.go @@ -1074,7 +1074,7 @@ func TestTMQModify(t *testing.T) { h2 := cgo.NewHandle(c2) targetConn, err := TaosConnect("", "root", "taosdata", "tmq_test_db_modify_target", 0) assert.NoError(t, err) - defer TaosFreeResult(targetConn) + defer TaosClose(targetConn) result = TaosQuery(conn, "create table stb (ts timestamp,"+ "c1 bool,"+ "c2 tinyint,"+ @@ -1170,70 +1170,41 @@ func TestTMQModify(t *testing.T) { } d, err := query(targetConn, "describe stb") assert.NoError(t, err) - if len(d[0]) == 4 { - assert.Equal(t, [][]driver.Value{ - {"ts", "TIMESTAMP", int32(8), ""}, - {"c1", "BOOL", int32(1), ""}, - {"c2", "TINYINT", int32(1), ""}, - {"c3", "SMALLINT", int32(2), ""}, - {"c4", "INT", int32(4), ""}, - {"c5", "BIGINT", int32(8), ""}, - {"c6", "TINYINT UNSIGNED", int32(1), ""}, - {"c7", "SMALLINT UNSIGNED", int32(2), ""}, - {"c8", "INT UNSIGNED", int32(4), ""}, - {"c9", "BIGINT UNSIGNED", int32(8), ""}, - {"c10", "FLOAT", int32(4), ""}, - {"c11", "DOUBLE", int32(8), ""}, - {"c12", "VARCHAR", int32(20), ""}, - {"c13", "NCHAR", int32(20), ""}, - {"tts", "TIMESTAMP", int32(8), "TAG"}, - {"tc1", "BOOL", int32(1), "TAG"}, - {"tc2", "TINYINT", int32(1), "TAG"}, - {"tc3", "SMALLINT", int32(2), "TAG"}, - {"tc4", "INT", int32(4), "TAG"}, - {"tc5", "BIGINT", int32(8), "TAG"}, - {"tc6", "TINYINT UNSIGNED", int32(1), "TAG"}, - {"tc7", "SMALLINT UNSIGNED", int32(2), "TAG"}, - {"tc8", "INT UNSIGNED", int32(4), "TAG"}, - {"tc9", "BIGINT UNSIGNED", int32(8), "TAG"}, - {"tc10", "FLOAT", int32(4), "TAG"}, - {"tc11", "DOUBLE", int32(8), "TAG"}, - {"tc12", "VARCHAR", int32(20), "TAG"}, - {"tc13", "NCHAR", int32(20), "TAG"}, - }, d) - } else { - assert.Equal(t, [][]driver.Value{ - {"ts", "TIMESTAMP", int32(8), "", ""}, - {"c1", "BOOL", int32(1), "", ""}, - {"c2", "TINYINT", int32(1), "", ""}, - {"c3", "SMALLINT", int32(2), "", ""}, - {"c4", "INT", int32(4), "", ""}, - {"c5", "BIGINT", int32(8), "", ""}, - {"c6", "TINYINT UNSIGNED", int32(1), "", ""}, - {"c7", "SMALLINT UNSIGNED", int32(2), "", ""}, - {"c8", "INT UNSIGNED", int32(4), "", ""}, - {"c9", "BIGINT UNSIGNED", int32(8), "", ""}, - {"c10", "FLOAT", int32(4), "", ""}, - {"c11", "DOUBLE", int32(8), "", ""}, - {"c12", "VARCHAR", int32(20), "", ""}, - {"c13", "NCHAR", int32(20), "", ""}, - {"tts", "TIMESTAMP", int32(8), "TAG", ""}, - {"tc1", "BOOL", int32(1), "TAG", ""}, - {"tc2", "TINYINT", int32(1), "TAG", ""}, - {"tc3", "SMALLINT", int32(2), "TAG", ""}, - {"tc4", "INT", int32(4), "TAG", ""}, - {"tc5", "BIGINT", int32(8), "TAG", ""}, - {"tc6", "TINYINT UNSIGNED", int32(1), "TAG", ""}, - {"tc7", "SMALLINT UNSIGNED", int32(2), "TAG", ""}, - {"tc8", "INT UNSIGNED", int32(4), "TAG", ""}, - {"tc9", "BIGINT UNSIGNED", int32(8), "TAG", ""}, - {"tc10", "FLOAT", int32(4), "TAG", ""}, - {"tc11", "DOUBLE", int32(8), "TAG", ""}, - {"tc12", "VARCHAR", int32(20), "TAG", ""}, - {"tc13", "NCHAR", int32(20), "TAG", ""}, - }, d) + expect := [][]driver.Value{ + {"ts", "TIMESTAMP", int32(8), ""}, + {"c1", "BOOL", int32(1), ""}, + {"c2", "TINYINT", int32(1), ""}, + {"c3", "SMALLINT", int32(2), ""}, + {"c4", "INT", int32(4), ""}, + {"c5", "BIGINT", int32(8), ""}, + {"c6", "TINYINT UNSIGNED", int32(1), ""}, + {"c7", "SMALLINT UNSIGNED", int32(2), ""}, + {"c8", "INT UNSIGNED", int32(4), ""}, + {"c9", "BIGINT UNSIGNED", int32(8), ""}, + {"c10", "FLOAT", int32(4), ""}, + {"c11", "DOUBLE", int32(8), ""}, + {"c12", "VARCHAR", int32(20), ""}, + {"c13", "NCHAR", int32(20), ""}, + {"tts", "TIMESTAMP", int32(8), "TAG"}, + {"tc1", "BOOL", int32(1), "TAG"}, + {"tc2", "TINYINT", int32(1), "TAG"}, + {"tc3", "SMALLINT", int32(2), "TAG"}, + {"tc4", "INT", int32(4), "TAG"}, + {"tc5", "BIGINT", int32(8), "TAG"}, + {"tc6", "TINYINT UNSIGNED", int32(1), "TAG"}, + {"tc7", "SMALLINT UNSIGNED", int32(2), "TAG"}, + {"tc8", "INT UNSIGNED", int32(4), "TAG"}, + {"tc9", "BIGINT UNSIGNED", int32(8), "TAG"}, + {"tc10", "FLOAT", int32(4), "TAG"}, + {"tc11", "DOUBLE", int32(8), "TAG"}, + {"tc12", "VARCHAR", int32(20), "TAG"}, + {"tc13", "NCHAR", int32(20), "TAG"}, + } + for rowIndex, values := range d { + for i := 0; i < 4; i++ { + assert.Equal(t, expect[rowIndex][i], values[i]) + } } - }) TMQUnsubscribe(tmq) diff --git a/ws/client/conn.go b/ws/client/conn.go index 37a5dd2..76506fc 100644 --- a/ws/client/conn.go +++ b/ws/client/conn.go @@ -3,6 +3,7 @@ package client import ( "bytes" "encoding/json" + "errors" "sync" "sync/atomic" "time" @@ -33,7 +34,7 @@ type EnvelopePool struct { func (ep *EnvelopePool) Get() *Envelope { epv := ep.p.Get() if epv == nil { - return &Envelope{Msg: new(bytes.Buffer)} + return &Envelope{Msg: new(bytes.Buffer), ErrorChan: make(chan error, 1)} } return epv.(*Envelope) } @@ -44,14 +45,24 @@ func (ep *EnvelopePool) Put(epv *Envelope) { } type Envelope struct { - Type int - Msg *bytes.Buffer + Type int + Msg *bytes.Buffer + ErrorChan chan error } func (e *Envelope) Reset() { - e.Msg.Reset() + if e.Msg.Cap() > 64*1024 { + e.Msg = new(bytes.Buffer) + } else { + e.Msg.Reset() + } + if len(e.ErrorChan) > 0 { + e.ErrorChan = make(chan error, 1) + } } +var ClosedError = errors.New("websocket closed") + type Client struct { conn *websocket.Conn status uint32 @@ -63,9 +74,10 @@ type Client struct { TextMessageHandler func(message []byte) BinaryMessageHandler func(message []byte) ErrorHandler func(err error) - SendMessageHandler func(envelope *Envelope) - once sync.Once - errHandlerOnce sync.Once + //SendMessageHandler func(envelope *Envelope) + once sync.Once + errHandlerOnce sync.Once + err error } func NewClient(conn *websocket.Conn, sendChanLength uint) *Client { @@ -80,9 +92,9 @@ func NewClient(conn *websocket.Conn, sendChanLength uint) *Client { TextMessageHandler: func(message []byte) {}, BinaryMessageHandler: func(message []byte) {}, ErrorHandler: func(err error) {}, - SendMessageHandler: func(envelope *Envelope) { - GlobalEnvelopePool.Put(envelope) - }, + //SendMessageHandler: func(envelope *Envelope) { + // GlobalEnvelopePool.Put(envelope) + //}, } } @@ -117,41 +129,61 @@ func (c *Client) WritePump() { defer func() { ticker.Stop() }() + for { select { case message, ok := <-c.sendChan: if !ok { - return + if message == nil { + return + } + message.ErrorChan <- ClosedError + continue } c.conn.SetWriteDeadline(time.Now().Add(c.WriteWait)) err := c.conn.WriteMessage(message.Type, message.Msg.Bytes()) if err != nil { + message.ErrorChan <- err c.handleError(err) - return + c.Close() + for message := range c.sendChan { + if message == nil { + return + } + message.ErrorChan <- ClosedError + } } - c.SendMessageHandler(message) + message.ErrorChan <- nil case <-ticker.C: c.conn.SetWriteDeadline(time.Now().Add(c.WriteWait)) if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil { c.handleError(err) - return + c.Close() + for message := range c.sendChan { + if message == nil { + return + } + message.ErrorChan <- ClosedError + } } } } } -func (c *Client) Send(envelope *Envelope) { +func (c *Client) Send(envelope *Envelope) error { if !c.IsRunning() { - return + return ClosedError } + var err error defer func() { // maybe closed if recover() != nil { - + err = ClosedError return } }() c.sendChan <- envelope + return err } func (c *Client) GetEnvelope() *Envelope { @@ -168,8 +200,8 @@ func (c *Client) IsRunning() bool { func (c *Client) Close() { c.once.Do(func() { - close(c.sendChan) atomic.StoreUint32(&c.status, StatusStop) + close(c.sendChan) if c.conn != nil { c.conn.Close() } diff --git a/ws/client/conn_test.go b/ws/client/conn_test.go new file mode 100644 index 0000000..c7f9a7c --- /dev/null +++ b/ws/client/conn_test.go @@ -0,0 +1,98 @@ +package client + +import ( + "bytes" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" +) + +func TestEnvelopePool(t *testing.T) { + pool := &EnvelopePool{} + + // Test Get method + env := pool.Get() + assert.NotNil(t, env) + assert.NotNil(t, env.Msg) + + // Test Put method + env.Msg.WriteString("test") + pool.Put(env) + + // Test if the envelope is reset after put + env = pool.Get() + assert.Equal(t, 0, env.Msg.Len()) +} + +func TestEnvelope_Reset(t *testing.T) { + env := &Envelope{ + Type: 1, + Msg: bytes.NewBufferString("test"), + } + + env.Reset() + + assert.Equal(t, 0, env.Msg.Len()) +} + +var upgrader = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, +} + +func wsEchoServer(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer conn.Close() + + for { + messageType, message, err := conn.ReadMessage() + if err != nil { + return + } + + if err := conn.WriteMessage(messageType, message); err != nil { + return + } + } +} + +func TestClient(t *testing.T) { + s := httptest.NewServer(http.HandlerFunc(wsEchoServer)) + defer s.Close() + t.Log(s.URL) + ep := "ws" + strings.TrimPrefix(s.URL, "http") + ws, _, err := websocket.DefaultDialer.Dial(ep, nil) + assert.NoError(t, err) + c := NewClient(ws, 1) + gotMessage := make(chan struct{}) + c.TextMessageHandler = func(message []byte) { + assert.Equal(t, "test", string(message)) + gotMessage <- struct{}{} + } + running := c.IsRunning() + assert.True(t, running) + defer c.Close() + go c.ReadPump() + go c.WritePump() + env := c.GetEnvelope() + env.Type = websocket.TextMessage + env.Msg.WriteString("test") + c.Send(env) + env = c.GetEnvelope() + c.PutEnvelope(env) + timeout := time.NewTimer(time.Second * 3) + select { + case <-gotMessage: + t.Log("got message") + case <-timeout.C: + t.Error("timeout") + } +} diff --git a/ws/schemaless/config.go b/ws/schemaless/config.go index 58f65b0..7599984 100644 --- a/ws/schemaless/config.go +++ b/ws/schemaless/config.go @@ -10,18 +10,22 @@ const ( ) type Config struct { - url string - chanLength uint - user string - password string - db string - readTimeout time.Duration - writeTimeout time.Duration - errorHandler func(error) + url string + chanLength uint + user string + password string + db string + readTimeout time.Duration + writeTimeout time.Duration + errorHandler func(error) + enableCompression bool + autoReconnect bool + reconnectIntervalMs int + reconnectRetryCount int } func NewConfig(url string, chanLength uint, opts ...func(*Config)) *Config { - c := Config{url: url, chanLength: chanLength} + c := Config{url: url, chanLength: chanLength, reconnectRetryCount: 3, reconnectIntervalMs: 2000} for _, opt := range opts { opt(&c) } @@ -64,3 +68,27 @@ func SetErrorHandler(errorHandler func(error)) func(*Config) { c.errorHandler = errorHandler } } + +func SetEnableCompression(enableCompression bool) func(*Config) { + return func(c *Config) { + c.enableCompression = enableCompression + } +} + +func SetAutoReconnect(reconnect bool) func(*Config) { + return func(c *Config) { + c.autoReconnect = reconnect + } +} + +func SetReconnectIntervalMs(reconnectIntervalMs int) func(*Config) { + return func(c *Config) { + c.reconnectIntervalMs = reconnectIntervalMs + } +} + +func SetReconnectRetryCount(reconnectRetryCount int) func(*Config) { + return func(c *Config) { + c.reconnectRetryCount = reconnectRetryCount + } +} diff --git a/ws/schemaless/schemaless.go b/ws/schemaless/schemaless.go index db44e09..f22c7a6 100644 --- a/ws/schemaless/schemaless.go +++ b/ws/schemaless/schemaless.go @@ -5,6 +5,7 @@ import ( "context" "errors" "fmt" + "net" "net/url" "sync" "time" @@ -23,17 +24,23 @@ const ( ) type Schemaless struct { - client *client.Client - sendList *list.List - url string - user string - password string - db string - readTimeout time.Duration - lock sync.Mutex - once sync.Once - closeChan chan struct{} - errorHandler func(error) + client *client.Client + sendList *list.List + url string + user string + password string + db string + readTimeout time.Duration + writeTimeout time.Duration + lock sync.Mutex + once sync.Once + closeChan chan struct{} + errorHandler func(error) + dialer *websocket.Dialer + chanLength uint + autoReconnect bool + reconnectIntervalMs int + reconnectRetryCount int } func NewSchemaless(config *Config) (*Schemaless, error) { @@ -44,23 +51,31 @@ func NewSchemaless(config *Config) (*Schemaless, error) { if wsUrl.Scheme != "ws" && wsUrl.Scheme != "wss" { return nil, errors.New("config url scheme error") } - if len(wsUrl.Path) == 0 || wsUrl.Path != "/rest/schemaless" { - wsUrl.Path = "/rest/schemaless" - } - ws, _, err := common.DefaultDialer.Dial(wsUrl.String(), nil) + wsUrl.Path = "/ws" + dialer := common.DefaultDialer + dialer.EnableCompression = config.enableCompression + conn, _, err := dialer.Dial(wsUrl.String(), nil) if err != nil { return nil, fmt.Errorf("dial ws error: %s", err) } - + conn.EnableWriteCompression(config.enableCompression) s := Schemaless{ - client: client.NewClient(ws, config.chanLength), + client: client.NewClient(conn, config.chanLength), sendList: list.New(), - url: config.url, + url: wsUrl.String(), user: config.user, password: config.password, db: config.db, closeChan: make(chan struct{}), errorHandler: config.errorHandler, + dialer: &dialer, + chanLength: config.chanLength, + } + + if config.autoReconnect { + s.autoReconnect = true + s.reconnectIntervalMs = config.reconnectIntervalMs + s.reconnectRetryCount = config.reconnectRetryCount } if config.readTimeout > 0 { @@ -68,21 +83,59 @@ func NewSchemaless(config *Config) (*Schemaless, error) { } if config.writeTimeout > 0 { - s.client.WriteWait = config.writeTimeout + s.writeTimeout = config.writeTimeout } - s.client.ErrorHandler = s.handleError - s.client.TextMessageHandler = s.handleTextMessage - - go s.client.ReadPump() - go s.client.WritePump() - if err = s.connect(); err != nil { + if err = connect(conn, s.user, s.password, s.db, s.writeTimeout, s.readTimeout); err != nil { return nil, fmt.Errorf("connect ws error: %s", err) } + s.initClient(s.client) return &s, nil } +func (s *Schemaless) initClient(c *client.Client) { + if s.writeTimeout > 0 { + c.WriteWait = s.writeTimeout + } + c.ErrorHandler = s.handleError + c.TextMessageHandler = s.handleTextMessage + + go c.ReadPump() + go c.WritePump() +} + +func (s *Schemaless) reconnect() error { + reconnected := false + for i := 0; i < s.reconnectRetryCount; i++ { + time.Sleep(time.Duration(s.reconnectIntervalMs) * time.Millisecond) + conn, _, err := s.dialer.Dial(s.url, nil) + if err != nil { + continue + } + conn.EnableWriteCompression(s.dialer.EnableCompression) + if err = connect(conn, s.user, s.password, s.db, s.writeTimeout, s.readTimeout); err != nil { + conn.Close() + continue + } + if s.client != nil { + s.client.Close() + } + c := client.NewClient(conn, s.chanLength) + s.initClient(c) + s.client = c + reconnected = true + break + } + if !reconnected { + if s.client != nil { + s.client.Close() + } + return errors.New("reconnect failed") + } + return nil +} + func (s *Schemaless) Insert(lines string, protocol int, precision string, ttl int, reqID int64) error { if reqID == 0 { reqID = common.GetReqID() @@ -101,15 +154,30 @@ func (s *Schemaless) Insert(lines string, protocol int, precision string, ttl in return err } action := &client.WSAction{Action: insertAction, Args: args} - envelope := s.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - s.client.PutEnvelope(envelope) return err } respBytes, err := s.sendText(uint64(reqID), envelope) if err != nil { - return err + if !s.autoReconnect { + return err + } + var opError *net.OpError + if errors.Is(err, client.ClosedError) || errors.As(err, &opError) { + err = s.reconnect() + if err != nil { + return err + } + respBytes, err = s.sendText(uint64(reqID), envelope) + if err != nil { + return err + } + } else { + return err + } } var resp schemalessResp err = client.JsonI.Unmarshal(respBytes, &resp) @@ -132,13 +200,16 @@ func (s *Schemaless) Close() { }) } -func (s *Schemaless) connect() error { - reqID := uint64(common.GetReqID()) +var ( + ConnectTimeoutErr = errors.New("schemaless connect timeout") +) + +func connect(ws *websocket.Conn, user string, password string, db string, writeTimeout time.Duration, readTimeout time.Duration) error { req := &wsConnectReq{ - ReqID: reqID, - User: s.user, - Password: s.password, - DB: s.db, + ReqID: 0, + User: user, + Password: password, + DB: db, } args, err := client.JsonI.Marshal(req) if err != nil { @@ -148,14 +219,29 @@ func (s *Schemaless) connect() error { Action: connAction, Args: args, } - envelope := s.client.GetEnvelope() - err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) + connectAction, err := client.JsonI.Marshal(action) if err != nil { - s.client.PutEnvelope(envelope) return err } - - respBytes, err := s.sendText(reqID, envelope) + ws.SetWriteDeadline(time.Now().Add(writeTimeout)) + err = ws.WriteMessage(websocket.TextMessage, connectAction) + if err != nil { + return err + } + done := make(chan struct{}) + ctx, cancel := context.WithTimeout(context.Background(), readTimeout) + var respBytes []byte + go func() { + _, respBytes, err = ws.ReadMessage() + close(done) + }() + select { + case <-done: + cancel() + case <-ctx.Done(): + cancel() + return ConnectTimeoutErr + } if err != nil { return err } @@ -181,7 +267,20 @@ func (s *Schemaless) send(reqID uint64, envelope *client.Envelope) ([]byte, erro channel: make(chan []byte, 1), } element := s.addMessageOutChan(channel) - s.client.Send(envelope) + err := s.client.Send(envelope) + if err != nil { + s.lock.Lock() + s.sendList.Remove(element) + s.lock.Unlock() + return nil, err + } + err = <-envelope.ErrorChan + if err != nil { + s.lock.Lock() + s.sendList.Remove(element) + s.lock.Unlock() + return nil, err + } ctx, cancel := context.WithTimeout(context.Background(), s.readTimeout) defer cancel() select { @@ -258,5 +357,4 @@ func (s *Schemaless) handleError(err error) { if s.errorHandler != nil { s.errorHandler(err) } - s.Close() } diff --git a/ws/schemaless/schemaless_test.go b/ws/schemaless/schemaless_test.go index bfa3199..dc9caa6 100644 --- a/ws/schemaless/schemaless_test.go +++ b/ws/schemaless/schemaless_test.go @@ -1,14 +1,20 @@ package schemaless import ( + "errors" "fmt" "io/ioutil" "net/http" + "os" + "os/exec" + "runtime" "strings" + "syscall" "testing" "time" jsoniter "github.com/json-iterator/go" + "github.com/stretchr/testify/assert" taosErrors "github.com/taosdata/driver-go/v3/errors" "github.com/taosdata/driver-go/v3/ws/client" ) @@ -62,20 +68,21 @@ func TestSchemaless_Insert(t *testing.T) { } defer func() { _ = after() }() - s, err := NewSchemaless(NewConfig("ws://localhost:6041/rest/schemaless", 1, + s, err := NewSchemaless(NewConfig("ws://localhost:6041", 1, SetDb("test_schemaless_ws"), SetReadTimeout(10*time.Second), SetWriteTimeout(10*time.Second), SetUser("root"), SetPassword("taosdata"), + SetEnableCompression(true), SetErrorHandler(func(err error) { - t.Fatal(err) + t.Log(err) }), )) if err != nil { t.Fatal(err) } - //defer s.Close() + defer s.Close() for _, c := range cases { t.Run(c.name, func(t *testing.T) { @@ -132,3 +139,103 @@ func before() error { func after() error { return doRequest("drop database test_schemaless_ws") } + +func newTaosadapter(port string) *exec.Cmd { + command := "taosadapter" + if runtime.GOOS == "windows" { + command = "C:\\TDengine\\taosadapter.exe" + + } + return exec.Command(command, "--port", port, "--logLevel", "debug") +} + +func startTaosadapter(cmd *exec.Cmd, port string) error { + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + err := cmd.Start() + if err != nil { + return err + } + for i := 0; i < 30; i++ { + time.Sleep(time.Millisecond * 100) + resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%s/-/ping", port)) + if err != nil { + continue + } + resp.Body.Close() + time.Sleep(time.Second) + return nil + } + return errors.New("taosadapter start failed") +} + +func stopTaosadapter(cmd *exec.Cmd) { + if cmd.Process == nil { + return + } + cmd.Process.Signal(syscall.SIGINT) + cmd.Process.Wait() + cmd.Process = nil + time.Sleep(time.Second) +} + +func TestSchemalessReconnect(t *testing.T) { + port := "36041" + cmd := newTaosadapter(port) + err := startTaosadapter(cmd, port) + if err != nil { + t.Fatal(err) + } + defer func() { + stopTaosadapter(cmd) + }() + err = doRequest("drop database if exists test_schemaless_reconnect") + if err != nil { + t.Fatal(err) + } + err = doRequest("create database if not exists test_schemaless_reconnect") + if err != nil { + t.Fatal(err) + } + s, err := NewSchemaless(NewConfig(fmt.Sprintf("ws://localhost:%s", port), 1, + SetDb("test_schemaless_reconnect"), + SetReadTimeout(3*time.Second), + SetWriteTimeout(3*time.Second), + SetUser("root"), + SetPassword("taosdata"), + //SetEnableCompression(true), + SetErrorHandler(func(err error) { + t.Log(err) + }), + SetAutoReconnect(true), + SetReconnectIntervalMs(2000), + SetReconnectRetryCount(3), + )) + if err != nil { + t.Fatal(err) + } + stopTaosadapter(cmd) + time.Sleep(time.Second * 3) + startChan := make(chan struct{}) + go func() { + time.Sleep(time.Second * 10) + err = startTaosadapter(cmd, port) + startChan <- struct{}{} + if err != nil { + t.Error(err) + return + } + }() + data := "measurement,host=host1 field1=2i,field2=2.0 1577837300000\n" + + "measurement,host=host1 field1=2i,field2=2.0 1577837400000\n" + + "measurement,host=host1 field1=2i,field2=2.0 1577837500000\n" + + "measurement,host=host1 field1=2i,field2=2.0 1577837600000" + err = s.Insert(data, InfluxDBLineProtocol, "ms", 0, 0) + assert.Error(t, err) + <-startChan + time.Sleep(time.Second) + err = s.Insert(data, InfluxDBLineProtocol, "ms", 0, 0) + assert.NoError(t, err) + err = s.Insert(data, InfluxDBLineProtocol, "ms", 0, 0) + assert.NoError(t, err) +} diff --git a/ws/stmt/config.go b/ws/stmt/config.go index 332ac55..3b533cc 100644 --- a/ws/stmt/config.go +++ b/ws/stmt/config.go @@ -6,21 +6,27 @@ import ( ) type Config struct { - Url string - ChanLength uint - MessageTimeout time.Duration - WriteWait time.Duration - ErrorHandler func(connector *Connector, err error) - CloseHandler func() - User string - Password string - DB string + Url string + ChanLength uint + MessageTimeout time.Duration + WriteWait time.Duration + ErrorHandler func(connector *Connector, err error) + CloseHandler func() + User string + Password string + DB string + EnableCompression bool + AutoReconnect bool + ReconnectIntervalMs int + ReconnectRetryCount int } func NewConfig(url string, chanLength uint) *Config { return &Config{ - Url: url, - ChanLength: chanLength, + Url: url, + ChanLength: chanLength, + ReconnectRetryCount: 3, + ReconnectIntervalMs: 2000, } } func (c *Config) SetConnectUser(user string) error { @@ -60,3 +66,19 @@ func (c *Config) SetErrorHandler(f func(connector *Connector, err error)) { func (c *Config) SetCloseHandler(f func()) { c.CloseHandler = f } + +func (c *Config) SetEnableCompression(enableCompression bool) { + c.EnableCompression = enableCompression +} + +func (c *Config) SetAutoReconnect(reconnect bool) { + c.AutoReconnect = reconnect +} + +func (c *Config) SetReconnectIntervalMs(reconnectIntervalMs int) { + c.ReconnectIntervalMs = reconnectIntervalMs +} + +func (c *Config) SetReconnectRetryCount(reconnectRetryCount int) { + c.ReconnectRetryCount = reconnectRetryCount +} diff --git a/ws/stmt/connector.go b/ws/stmt/connector.go index 01ba361..a08cd2e 100644 --- a/ws/stmt/connector.go +++ b/ws/stmt/connector.go @@ -3,8 +3,11 @@ package stmt import ( "container/list" "context" + "encoding/binary" "errors" "fmt" + "net" + "net/url" "sync" "sync/atomic" "time" @@ -17,17 +20,26 @@ import ( ) type Connector struct { - client *client.Client - requestID uint64 - listLock sync.RWMutex - sendChanList *list.List - writeTimeout time.Duration - readTimeout time.Duration - config *Config - closeOnce sync.Once - closeChan chan struct{} - customErrorHandler func(*Connector, error) - customCloseHandler func() + client *client.Client + requestID uint64 + listLock sync.RWMutex + sendChanList *list.List + writeTimeout time.Duration + readTimeout time.Duration + config *Config + closeOnce sync.Once + closeChan chan struct{} + customErrorHandler func(*Connector, error) + customCloseHandler func() + url string + chanLength uint + dialer *websocket.Dialer + autoReconnect bool + reconnectIntervalMs int + reconnectRetryCount int + user string + password string + db string } var ( @@ -44,10 +56,18 @@ func NewConnector(config *Config) (*Connector, error) { if config.WriteWait > 0 { writeTimeout = config.WriteWait } - ws, _, err := common.DefaultDialer.Dial(config.Url, nil) + dialer := common.DefaultDialer + dialer.EnableCompression = config.EnableCompression + u, err := url.Parse(config.Url) if err != nil { return nil, err } + u.Path = "/ws" + ws, _, err := dialer.Dial(u.String(), nil) + if err != nil { + return nil, err + } + ws.EnableWriteCompression(config.EnableCompression) defer func() { if connector == nil { ws.Close() @@ -56,15 +76,58 @@ func NewConnector(config *Config) (*Connector, error) { if config.MessageTimeout <= 0 { config.MessageTimeout = common.DefaultMessageTimeout } + err = connect(ws, config.User, config.Password, config.DB, writeTimeout, readTimeout) + if err != nil { + return nil, err + } + wsClient := client.NewClient(ws, config.ChanLength) + connector = &Connector{ + client: wsClient, + requestID: 0, + listLock: sync.RWMutex{}, + sendChanList: list.New(), + writeTimeout: writeTimeout, + readTimeout: readTimeout, + config: config, + closeOnce: sync.Once{}, + closeChan: make(chan struct{}), + customErrorHandler: config.ErrorHandler, + customCloseHandler: config.CloseHandler, + url: u.String(), + dialer: &dialer, + chanLength: config.ChanLength, + autoReconnect: config.AutoReconnect, + reconnectIntervalMs: config.ReconnectIntervalMs, + reconnectRetryCount: config.ReconnectRetryCount, + user: config.User, + password: config.Password, + db: config.DB, + } + connector.initClient(connector.client) + return connector, nil +} + +func (c *Connector) initClient(client *client.Client) { + if c.writeTimeout > 0 { + client.WriteWait = c.writeTimeout + } + client.TextMessageHandler = c.handleTextMessage + client.BinaryMessageHandler = c.handleBinaryMessage + client.ErrorHandler = c.handleError + go client.WritePump() + go client.ReadPump() +} + +func connect(ws *websocket.Conn, user string, password string, db string, writeTimeout time.Duration, readTimeout time.Duration) error { req := &ConnectReq{ ReqID: 0, - User: config.User, - Password: config.Password, - DB: config.DB, + User: user, + Password: password, + DB: db, } args, err := client.JsonI.Marshal(req) if err != nil { - return nil, err + return err } action := &client.WSAction{ Action: STMTConnect, @@ -72,12 +135,12 @@ func NewConnector(config *Config) (*Connector, error) { } connectAction, err := client.JsonI.Marshal(action) if err != nil { - return nil, err + return err } ws.SetWriteDeadline(time.Now().Add(writeTimeout)) err = ws.WriteMessage(websocket.TextMessage, connectAction) if err != nil { - return nil, err + return err } done := make(chan struct{}) ctx, cancel := context.WithTimeout(context.Background(), readTimeout) @@ -91,40 +154,20 @@ func NewConnector(config *Config) (*Connector, error) { cancel() case <-ctx.Done(): cancel() - return nil, ConnectTimeoutErr + return ConnectTimeoutErr } if err != nil { - return nil, err + return err } var resp ConnectResp err = client.JsonI.Unmarshal(respBytes, &resp) if err != nil { - return nil, err + return err } if resp.Code != 0 { - return nil, taosErrors.NewError(resp.Code, resp.Message) - } - wsClient := client.NewClient(ws, config.ChanLength) - wsClient.WriteWait = writeTimeout - connector = &Connector{ - client: wsClient, - requestID: 0, - listLock: sync.RWMutex{}, - sendChanList: list.New(), - writeTimeout: writeTimeout, - readTimeout: readTimeout, - config: config, - closeOnce: sync.Once{}, - closeChan: make(chan struct{}), - customErrorHandler: config.ErrorHandler, - customCloseHandler: config.CloseHandler, + return taosErrors.NewError(resp.Code, resp.Message) } - - wsClient.TextMessageHandler = connector.handleTextMessage - wsClient.ErrorHandler = connector.handleError - go wsClient.WritePump() - go wsClient.ReadPump() - return connector, nil + return nil } func (c *Connector) handleTextMessage(message []byte) { @@ -150,6 +193,17 @@ func (c *Connector) handleTextMessage(message []byte) { c.listLock.Unlock() } +func (c *Connector) handleBinaryMessage(message []byte) { + reqID := binary.LittleEndian.Uint64(message[8:16]) + c.listLock.Lock() + element := c.findOutChanByID(reqID) + if element != nil { + element.Value.(*IndexedChan).channel <- message + c.sendChanList.Remove(element) + } + c.listLock.Unlock() +} + type IndexedChan struct { index uint64 channel chan []byte @@ -169,7 +223,20 @@ func (c *Connector) send(reqID uint64, envelope *client.Envelope) ([]byte, error channel: make(chan []byte, 1), } element := c.addMessageOutChan(channel) - c.client.Send(envelope) + err := c.client.Send(envelope) + if err != nil { + c.listLock.Lock() + c.sendChanList.Remove(element) + c.listLock.Unlock() + return nil, err + } + err = <-envelope.ErrorChan + if err != nil { + c.listLock.Lock() + c.sendChanList.Remove(element) + c.listLock.Unlock() + return nil, err + } ctx, cancel := context.WithTimeout(context.Background(), c.readTimeout) defer cancel() select { @@ -188,6 +255,7 @@ func (c *Connector) send(reqID uint64, envelope *client.Envelope) ([]byte, error func (c *Connector) sendTextWithoutResp(envelope *client.Envelope) { envelope.Type = websocket.TextMessage c.client.Send(envelope) + <-envelope.ErrorChan } func (c *Connector) findOutChanByID(index uint64) *list.Element { @@ -222,13 +290,45 @@ func (c *Connector) handleError(err error) { if c.customErrorHandler != nil { c.customErrorHandler(c, err) } - c.Close() + //c.Close() } func (c *Connector) generateReqID() uint64 { return atomic.AddUint64(&c.requestID, 1) } +func (c *Connector) reconnect() error { + reconnected := false + for i := 0; i < c.reconnectRetryCount; i++ { + time.Sleep(time.Duration(c.reconnectIntervalMs) * time.Millisecond) + conn, _, err := c.dialer.Dial(c.url, nil) + if err != nil { + continue + } + conn.EnableWriteCompression(c.dialer.EnableCompression) + err = connect(conn, c.user, c.password, c.db, c.writeTimeout, c.readTimeout) + if err != nil { + conn.Close() + continue + } + if c.client != nil { + c.client.Close() + } + cl := client.NewClient(conn, c.chanLength) + c.initClient(cl) + c.client = cl + reconnected = true + break + } + if !reconnected { + if c.client != nil { + c.client.Close() + } + return errors.New("reconnect failed") + } + return nil +} + func (c *Connector) Init() (*Stmt, error) { reqID := c.generateReqID() req := &InitReq{ @@ -242,15 +342,30 @@ func (c *Connector) Init() (*Stmt, error) { Action: STMTInit, Args: args, } - envelope := c.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - c.client.PutEnvelope(envelope) return nil, err } respBytes, err := c.sendText(reqID, envelope) if err != nil { - return nil, err + if !c.autoReconnect { + return nil, err + } + var opError *net.OpError + if errors.Is(err, client.ClosedError) || errors.As(err, &opError) { + err = c.reconnect() + if err != nil { + return nil, err + } + respBytes, err = c.sendText(reqID, envelope) + if err != nil { + return nil, err + } + } else { + return nil, err + } } var resp InitResp err = client.JsonI.Unmarshal(respBytes, &resp) diff --git a/ws/stmt/proto.go b/ws/stmt/proto.go index 2fed0ab..b5dc92d 100644 --- a/ws/stmt/proto.go +++ b/ws/stmt/proto.go @@ -15,6 +15,10 @@ const ( STMTAddBatch = "add_batch" STMTExec = "exec" STMTClose = "close" + STMTUseResult = "use_result" + WSFetch = "fetch" + WSFetchBlock = "fetch_block" + WSFreeResult = "free_result" ) type ConnectReq struct { @@ -134,3 +138,50 @@ type CloseReq struct { ReqID uint64 `json:"req_id"` StmtID uint64 `json:"stmt_id"` } + +type UseResultReq struct { + ReqID uint64 `json:"req_id"` + StmtID uint64 `json:"stmt_id"` +} + +type UseResultResp struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` + ResultID uint64 `json:"result_id"` + FieldsCount int `json:"fields_count"` + FieldsNames []string `json:"fields_names"` + FieldsTypes []uint8 `json:"fields_types"` + FieldsLengths []int64 `json:"fields_lengths"` + Precision int `json:"precision"` +} + +type WSFetchReq struct { + ReqID uint64 `json:"req_id"` + ID uint64 `json:"id"` +} + +type WSFetchResp struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + ID uint64 `json:"id"` + Completed bool `json:"completed"` + Lengths []int `json:"lengths"` + Rows int `json:"rows"` +} + +type WSFetchBlockReq struct { + ReqID uint64 `json:"req_id"` + ID uint64 `json:"id"` +} + +type WSFreeResultRequest struct { + ReqID uint64 `json:"req_id"` + ID uint64 `json:"id"` +} diff --git a/ws/stmt/rows.go b/ws/stmt/rows.go new file mode 100644 index 0000000..5247b55 --- /dev/null +++ b/ws/stmt/rows.go @@ -0,0 +1,172 @@ +package stmt + +import ( + "bytes" + "database/sql/driver" + "encoding/json" + "io" + "reflect" + "unsafe" + + "github.com/taosdata/driver-go/v3/common" + "github.com/taosdata/driver-go/v3/common/parser" + "github.com/taosdata/driver-go/v3/common/pointer" + taosErrors "github.com/taosdata/driver-go/v3/errors" + "github.com/taosdata/driver-go/v3/ws/client" +) + +type Rows struct { + buf *bytes.Buffer + blockPtr unsafe.Pointer + blockOffset int + blockSize int + resultID uint64 + block []byte + conn *Connector + client *client.Client + fieldsCount int + fieldsNames []string + fieldsTypes []uint8 + fieldsLengths []int64 + precision int +} + +func (rs *Rows) Columns() []string { + return rs.fieldsNames +} + +func (rs *Rows) ColumnTypeDatabaseTypeName(i int) string { + return common.TypeNameMap[int(rs.fieldsTypes[i])] +} + +func (rs *Rows) ColumnTypeLength(i int) (length int64, ok bool) { + return rs.fieldsLengths[i], ok +} + +func (rs *Rows) ColumnTypeScanType(i int) reflect.Type { + t, exist := common.ColumnTypeMap[int(rs.fieldsTypes[i])] + if !exist { + return common.UnknownType + } + return t +} + +func (rs *Rows) Close() error { + rs.blockPtr = nil + rs.block = nil + return rs.freeResult() +} + +func (rs *Rows) Next(dest []driver.Value) error { + if rs.blockPtr == nil || rs.blockOffset >= rs.blockSize { + err := rs.taosFetchBlock() + if err != nil { + return err + } + } + if rs.blockSize == 0 { + rs.blockPtr = nil + rs.block = nil + return io.EOF + } + parser.ReadRow(dest, rs.blockPtr, rs.blockSize, rs.blockOffset, rs.fieldsTypes, rs.precision) + rs.blockOffset += 1 + return nil +} + +func (rs *Rows) taosFetchBlock() error { + reqID := rs.conn.generateReqID() + req := &WSFetchReq{ + ReqID: reqID, + ID: rs.resultID, + } + args, err := json.Marshal(req) + if err != nil { + return err + } + action := &client.WSAction{ + Action: WSFetch, + Args: args, + } + rs.buf.Reset() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) + err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) + if err != nil { + return err + } + respBytes, err := rs.conn.sendText(reqID, envelope) + if err != nil { + return err + } + var resp WSFetchResp + err = client.JsonI.Unmarshal(respBytes, &resp) + if err != nil { + return err + } + if resp.Code != 0 { + return taosErrors.NewError(resp.Code, resp.Message) + } + if resp.Completed { + rs.blockSize = 0 + return nil + } else { + rs.blockSize = resp.Rows + return rs.fetchBlock() + } +} + +func (rs *Rows) fetchBlock() error { + req := &WSFetchBlockReq{ + ReqID: rs.resultID, + ID: rs.resultID, + } + args, err := client.JsonI.Marshal(req) + if err != nil { + return err + } + action := &client.WSAction{ + Action: WSFetchBlock, + Args: args, + } + rs.buf.Reset() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) + err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) + if err != nil { + return err + } + respBytes, err := rs.conn.sendText(rs.resultID, envelope) + if err != nil { + return err + } + rs.block = respBytes + rs.blockPtr = pointer.AddUintptr(unsafe.Pointer(&rs.block[0]), 16) + rs.blockOffset = 0 + return nil +} + +func (rs *Rows) freeResult() error { + reqID := rs.conn.generateReqID() + req := &WSFreeResultRequest{ + ReqID: reqID, + ID: rs.resultID, + } + args, err := client.JsonI.Marshal(req) + if err != nil { + return err + } + action := &client.WSAction{ + Action: WSFreeResult, + Args: args, + } + rs.buf.Reset() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) + err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) + if err != nil { + return err + } + rs.conn.sendTextWithoutResp(envelope) + return nil +} diff --git a/ws/stmt/stmt.go b/ws/stmt/stmt.go index e3c4c74..373b763 100644 --- a/ws/stmt/stmt.go +++ b/ws/stmt/stmt.go @@ -1,6 +1,7 @@ package stmt import ( + "bytes" "encoding/binary" "github.com/taosdata/driver-go/v3/common/param" @@ -30,10 +31,10 @@ func (s *Stmt) Prepare(sql string) error { Action: STMTPrepare, Args: args, } - envelope := s.connector.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - s.connector.client.PutEnvelope(envelope) return err } respBytes, err := s.connector.sendText(reqID, envelope) @@ -66,10 +67,10 @@ func (s *Stmt) SetTableName(name string) error { Action: STMTSetTableName, Args: args, } - envelope := s.connector.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - s.connector.client.PutEnvelope(envelope) return err } respBytes, err := s.connector.sendText(reqID, envelope) @@ -102,7 +103,8 @@ func (s *Stmt) SetTags(tags *param.Param, bindType *param.ColumnType) error { binary.LittleEndian.PutUint64(reqData, reqID) binary.LittleEndian.PutUint64(reqData[8:], s.id) binary.LittleEndian.PutUint64(reqData[16:], SetTagsMessage) - envelope := s.connector.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) envelope.Msg.Grow(24 + len(block)) envelope.Msg.Write(reqData) envelope.Msg.Write(block) @@ -131,13 +133,13 @@ func (s *Stmt) BindParam(params []*param.Param, bindType *param.ColumnType) erro binary.LittleEndian.PutUint64(reqData, reqID) binary.LittleEndian.PutUint64(reqData[8:], s.id) binary.LittleEndian.PutUint64(reqData[16:], BindMessage) - envelope := s.connector.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) envelope.Msg.Grow(24 + len(block)) envelope.Msg.Write(reqData) envelope.Msg.Write(block) err = client.JsonI.NewEncoder(envelope.Msg).Encode(reqData) if err != nil { - s.connector.client.PutEnvelope(envelope) return err } respBytes, err := s.connector.sendBinary(reqID, envelope) @@ -169,10 +171,10 @@ func (s *Stmt) AddBatch() error { Action: STMTAddBatch, Args: args, } - envelope := s.connector.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - s.connector.client.PutEnvelope(envelope) return err } respBytes, err := s.connector.sendText(reqID, envelope) @@ -204,10 +206,10 @@ func (s *Stmt) Exec() error { Action: STMTExec, Args: args, } - envelope := s.connector.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - s.connector.client.PutEnvelope(envelope) return err } respBytes, err := s.connector.sendText(reqID, envelope) @@ -230,6 +232,51 @@ func (s *Stmt) GetAffectedRows() int { return s.lastAffected } +func (s *Stmt) UseResult() (*Rows, error) { + reqID := s.connector.generateReqID() + req := &UseResultReq{ + ReqID: reqID, + StmtID: s.id, + } + args, err := client.JsonI.Marshal(req) + if err != nil { + return nil, err + } + action := &client.WSAction{ + Action: STMTUseResult, + Args: args, + } + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) + err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) + if err != nil { + return nil, err + } + respBytes, err := s.connector.sendText(reqID, envelope) + if err != nil { + return nil, err + } + var resp UseResultResp + err = client.JsonI.Unmarshal(respBytes, &resp) + if err != nil { + return nil, err + } + if resp.Code != 0 { + return nil, taosErrors.NewError(resp.Code, resp.Message) + } + return &Rows{ + buf: &bytes.Buffer{}, + conn: s.connector, + client: s.connector.client, + resultID: resp.ResultID, + fieldsCount: resp.FieldsCount, + fieldsNames: resp.FieldsNames, + fieldsTypes: resp.FieldsTypes, + fieldsLengths: resp.FieldsLengths, + precision: resp.Precision, + }, nil +} + func (s *Stmt) Close() error { reqID := s.connector.generateReqID() req := &CloseReq{ @@ -244,10 +291,10 @@ func (s *Stmt) Close() error { Action: STMTClose, Args: args, } - envelope := s.connector.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - s.connector.client.PutEnvelope(envelope) return err } s.connector.sendTextWithoutResp(envelope) diff --git a/ws/stmt/stmt_test.go b/ws/stmt/stmt_test.go index 4cc1631..652766e 100644 --- a/ws/stmt/stmt_test.go +++ b/ws/stmt/stmt_test.go @@ -2,11 +2,16 @@ package stmt import ( "database/sql/driver" + "errors" "fmt" "io" "io/ioutil" "net/http" + "os" + "os/exec" + "runtime" "strings" + "syscall" "testing" "time" @@ -18,12 +23,12 @@ import ( "github.com/taosdata/driver-go/v3/ws/client" ) -func prepareEnv() error { +func prepareEnv(db string) error { var err error steps := []string{ - "drop database if exists test_ws_stmt", - "create database test_ws_stmt", - "create table test_ws_stmt.all_json(ts timestamp," + + "drop database if exists " + db, + "create database " + db, + "create table " + db + ".all_json(ts timestamp," + "c1 bool," + "c2 tinyint," + "c3 smallint," + @@ -39,7 +44,7 @@ func prepareEnv() error { "c13 nchar(20)" + ")" + "tags(t json)", - "create table test_ws_stmt.all_all(" + + "create table " + db + ".all_all(" + "ts timestamp," + "c1 bool," + "c2 tinyint," + @@ -80,11 +85,11 @@ func prepareEnv() error { return nil } -func cleanEnv() error { +func cleanEnv(db string) error { var err error time.Sleep(2 * time.Second) steps := []string{ - "drop database if exists test_ws_stmt", + "drop database if exists " + db, } for _, step := range steps { err = doRequest(step) @@ -151,19 +156,20 @@ func query(payload string) (*common.TDEngineRestfulResp, error) { // @date: 2023/10/13 11:35 // @description: test stmt over websocket func TestStmt(t *testing.T) { - err := prepareEnv() + err := prepareEnv("test_ws_stmt") if err != nil { t.Error(err) return } - defer cleanEnv() + defer cleanEnv("test_ws_stmt") now := time.Now() - config := NewConfig("ws://127.0.0.1:6041/rest/stmt", 0) + config := NewConfig("ws://127.0.0.1:6041", 0) config.SetConnectUser("root") config.SetConnectPass("taosdata") config.SetConnectDB("test_ws_stmt") config.SetMessageTimeout(common.DefaultMessageTimeout) config.SetWriteWait(common.DefaultWriteWait) + config.SetEnableCompression(true) config.SetErrorHandler(func(connector *Connector, err error) { t.Log(err) }) @@ -614,3 +620,496 @@ func marshalBody(body io.Reader, bufferSize int) (*common.TDEngineRestfulResp, e } return &result, nil } + +func TestSTMTQuery(t *testing.T) { + err := prepareEnv("test_ws_stmt_query") + if err != nil { + t.Error(err) + return + } + defer cleanEnv("test_ws_stmt_query") + now := time.Now() + config := NewConfig("ws://127.0.0.1:6041", 0) + config.SetConnectUser("root") + config.SetConnectPass("taosdata") + config.SetConnectDB("test_ws_stmt_query") + config.SetMessageTimeout(common.DefaultMessageTimeout) + config.SetWriteWait(common.DefaultWriteWait) + config.SetEnableCompression(true) + config.SetErrorHandler(func(connector *Connector, err error) { + t.Log(err) + }) + config.SetCloseHandler(func() { + t.Log("stmt websocket closed") + }) + connector, err := NewConnector(config) + if err != nil { + t.Error(err) + return + } + defer connector.Close() + { + stmt, err := connector.Init() + if err != nil { + t.Error(err) + return + } + defer stmt.Close() + err = stmt.Prepare("insert into ? using all_json tags(?) values(?,?,?,?,?,?,?,?,?,?,?,?,?,?)") + if err != nil { + t.Error(err) + return + } + err = stmt.SetTableName("tb1") + if err != nil { + t.Error(err) + return + } + err = stmt.SetTags(param.NewParam(1).AddJson([]byte(`{"tb":1}`)), param.NewColumnType(1).AddJson(0)) + if err != nil { + t.Error(err) + return + } + params := []*param.Param{ + param.NewParam(3).AddTimestamp(now, 0).AddTimestamp(now.Add(time.Second), 0).AddTimestamp(now.Add(time.Second*2), 0), + param.NewParam(3).AddBool(true).AddNull().AddBool(true), + param.NewParam(3).AddTinyint(1).AddNull().AddTinyint(1), + param.NewParam(3).AddSmallint(1).AddNull().AddSmallint(1), + param.NewParam(3).AddInt(1).AddNull().AddInt(1), + param.NewParam(3).AddBigint(1).AddNull().AddBigint(1), + param.NewParam(3).AddUTinyint(1).AddNull().AddUTinyint(1), + param.NewParam(3).AddUSmallint(1).AddNull().AddUSmallint(1), + param.NewParam(3).AddUInt(1).AddNull().AddUInt(1), + param.NewParam(3).AddUBigint(1).AddNull().AddUBigint(1), + param.NewParam(3).AddFloat(1).AddNull().AddFloat(1), + param.NewParam(3).AddDouble(1).AddNull().AddDouble(1), + param.NewParam(3).AddBinary([]byte("test_binary")).AddNull().AddBinary([]byte("test_binary")), + param.NewParam(3).AddNchar("test_nchar").AddNull().AddNchar("test_nchar"), + } + paramTypes := param.NewColumnType(14). + AddTimestamp(). + AddBool(). + AddTinyint(). + AddSmallint(). + AddInt(). + AddBigint(). + AddUTinyint(). + AddUSmallint(). + AddUInt(). + AddUBigint(). + AddFloat(). + AddDouble(). + AddBinary(0). + AddNchar(0) + err = stmt.BindParam(params, paramTypes) + if err != nil { + t.Error(err) + return + } + err = stmt.AddBatch() + if err != nil { + t.Error(err) + return + } + err = stmt.Exec() + if err != nil { + t.Error(err) + return + } + affected := stmt.GetAffectedRows() + if !assert.Equal(t, 3, affected) { + return + } + err = stmt.Prepare("select * from all_json where ts >=? order by ts") + assert.NoError(t, err) + queryTime := now.Format(time.RFC3339Nano) + params = []*param.Param{param.NewParam(1).AddBinary([]byte(queryTime))} + paramTypes = param.NewColumnType(1).AddBinary(len(queryTime)) + err = stmt.BindParam(params, paramTypes) + assert.NoError(t, err) + err = stmt.AddBatch() + assert.NoError(t, err) + err = stmt.Exec() + assert.NoError(t, err) + rows, err := stmt.UseResult() + assert.NoError(t, err) + columns := rows.Columns() + assert.Equal(t, 15, len(columns)) + expectColumns := []string{ + "ts", + "c1", + "c2", + "c3", + "c4", + "c5", + "c6", + "c7", + "c8", + "c9", + "c10", + "c11", + "c12", + "c13", + "t", + } + for i := 0; i < 14; i++ { + assert.Equal(t, columns[i], expectColumns[i]) + rows.ColumnTypeDatabaseTypeName(i) + rows.ColumnTypeLength(i) + rows.ColumnTypeScanType(i) + } + var result [][]driver.Value + for { + values := make([]driver.Value, 15) + err = rows.Next(values) + if err != nil { + if err == io.EOF { + break + } + assert.NoError(t, err) + } + result = append(result, values) + } + assert.Equal(t, 3, len(result)) + row1 := result[0] + assert.Equal(t, now.UnixNano()/1e6, row1[0].(time.Time).UnixNano()/1e6) + assert.Equal(t, true, row1[1]) + assert.Equal(t, int8(1), row1[2]) + assert.Equal(t, int16(1), row1[3]) + assert.Equal(t, int32(1), row1[4]) + assert.Equal(t, int64(1), row1[5]) + assert.Equal(t, uint8(1), row1[6]) + assert.Equal(t, uint16(1), row1[7]) + assert.Equal(t, uint32(1), row1[8]) + assert.Equal(t, uint64(1), row1[9]) + assert.Equal(t, float32(1), row1[10]) + assert.Equal(t, float64(1), row1[11]) + assert.Equal(t, "test_binary", row1[12]) + assert.Equal(t, "test_nchar", row1[13]) + assert.Equal(t, []byte(`{"tb":1}`), row1[14]) + row2 := result[1] + assert.Equal(t, now.Add(time.Second).UnixNano()/1e6, row2[0].(time.Time).UnixNano()/1e6) + for i := 1; i < 14; i++ { + assert.Nil(t, row2[i]) + } + assert.Equal(t, []byte(`{"tb":1}`), row2[14]) + row3 := result[2] + assert.Equal(t, now.Add(time.Second*2).UnixNano()/1e6, row3[0].(time.Time).UnixNano()/1e6) + assert.Equal(t, true, row3[1]) + assert.Equal(t, int8(1), row3[2]) + assert.Equal(t, int16(1), row3[3]) + assert.Equal(t, int32(1), row3[4]) + assert.Equal(t, int64(1), row3[5]) + assert.Equal(t, uint8(1), row3[6]) + assert.Equal(t, uint16(1), row3[7]) + assert.Equal(t, uint32(1), row3[8]) + assert.Equal(t, uint64(1), row3[9]) + assert.Equal(t, float32(1), row3[10]) + assert.Equal(t, float64(1), row3[11]) + assert.Equal(t, "test_binary", row3[12]) + assert.Equal(t, "test_nchar", row3[13]) + assert.Equal(t, []byte(`{"tb":1}`), row3[14]) + } + { + stmt, err := connector.Init() + if err != nil { + t.Error(err) + return + } + defer stmt.Close() + err = stmt.Prepare("insert into ? using all_all tags(?,?,?,?,?,?,?,?,?,?,?,?,?,?) values(?,?,?,?,?,?,?,?,?,?,?,?,?,?)") + err = stmt.SetTableName("tb1") + if err != nil { + t.Error(err) + return + } + + err = stmt.SetTableName("tb2") + if err != nil { + t.Error(err) + return + } + err = stmt.SetTags( + param.NewParam(14). + AddTimestamp(now, 0). + AddBool(true). + AddTinyint(2). + AddSmallint(2). + AddInt(2). + AddBigint(2). + AddUTinyint(2). + AddUSmallint(2). + AddUInt(2). + AddUBigint(2). + AddFloat(2). + AddDouble(2). + AddBinary([]byte("tb2")). + AddNchar("tb2"), + param.NewColumnType(14). + AddTimestamp(). + AddBool(). + AddTinyint(). + AddSmallint(). + AddInt(). + AddBigint(). + AddUTinyint(). + AddUSmallint(). + AddUInt(). + AddUBigint(). + AddFloat(). + AddDouble(). + AddBinary(0). + AddNchar(0), + ) + if err != nil { + t.Error(err) + return + } + params := []*param.Param{ + param.NewParam(3).AddTimestamp(now, 0).AddTimestamp(now.Add(time.Second), 0).AddTimestamp(now.Add(time.Second*2), 0), + param.NewParam(3).AddBool(true).AddNull().AddBool(true), + param.NewParam(3).AddTinyint(1).AddNull().AddTinyint(1), + param.NewParam(3).AddSmallint(1).AddNull().AddSmallint(1), + param.NewParam(3).AddInt(1).AddNull().AddInt(1), + param.NewParam(3).AddBigint(1).AddNull().AddBigint(1), + param.NewParam(3).AddUTinyint(1).AddNull().AddUTinyint(1), + param.NewParam(3).AddUSmallint(1).AddNull().AddUSmallint(1), + param.NewParam(3).AddUInt(1).AddNull().AddUInt(1), + param.NewParam(3).AddUBigint(1).AddNull().AddUBigint(1), + param.NewParam(3).AddFloat(1).AddNull().AddFloat(1), + param.NewParam(3).AddDouble(1).AddNull().AddDouble(1), + param.NewParam(3).AddBinary([]byte("test_binary")).AddNull().AddBinary([]byte("test_binary")), + param.NewParam(3).AddNchar("test_nchar").AddNull().AddNchar("test_nchar"), + } + paramTypes := param.NewColumnType(14). + AddTimestamp(). + AddBool(). + AddTinyint(). + AddSmallint(). + AddInt(). + AddBigint(). + AddUTinyint(). + AddUSmallint(). + AddUInt(). + AddUBigint(). + AddFloat(). + AddDouble(). + AddBinary(0). + AddNchar(0) + err = stmt.BindParam(params, paramTypes) + if err != nil { + t.Error(err) + return + } + err = stmt.AddBatch() + if err != nil { + t.Error(err) + return + } + err = stmt.Exec() + if err != nil { + t.Error(err) + return + } + affected := stmt.GetAffectedRows() + if !assert.Equal(t, 3, affected) { + return + } + err = stmt.Prepare("select * from all_all where ts >=? order by ts") + assert.NoError(t, err) + queryTime := now.Format(time.RFC3339Nano) + params = []*param.Param{param.NewParam(1).AddBinary([]byte(queryTime))} + paramTypes = param.NewColumnType(1).AddBinary(len(queryTime)) + err = stmt.BindParam(params, paramTypes) + assert.NoError(t, err) + err = stmt.AddBatch() + assert.NoError(t, err) + err = stmt.Exec() + assert.NoError(t, err) + rows, err := stmt.UseResult() + assert.NoError(t, err) + columns := rows.Columns() + assert.Equal(t, 28, len(columns)) + var result [][]driver.Value + for { + values := make([]driver.Value, 28) + err = rows.Next(values) + if err != nil { + if err == io.EOF { + break + } + assert.NoError(t, err) + } + result = append(result, values) + } + assert.Equal(t, 3, len(result)) + row1 := result[0] + assert.Equal(t, now.UnixNano()/1e6, row1[0].(time.Time).UnixNano()/1e6) + assert.Equal(t, true, row1[1]) + assert.Equal(t, int8(1), row1[2]) + assert.Equal(t, int16(1), row1[3]) + assert.Equal(t, int32(1), row1[4]) + assert.Equal(t, int64(1), row1[5]) + assert.Equal(t, uint8(1), row1[6]) + assert.Equal(t, uint16(1), row1[7]) + assert.Equal(t, uint32(1), row1[8]) + assert.Equal(t, uint64(1), row1[9]) + assert.Equal(t, float32(1), row1[10]) + assert.Equal(t, float64(1), row1[11]) + assert.Equal(t, "test_binary", row1[12]) + assert.Equal(t, "test_nchar", row1[13]) + assert.Equal(t, now.UnixNano()/1e6, row1[14].(time.Time).UnixNano()/1e6) + assert.Equal(t, true, row1[15]) + assert.Equal(t, int8(2), row1[16]) + assert.Equal(t, int16(2), row1[17]) + assert.Equal(t, int32(2), row1[18]) + assert.Equal(t, int64(2), row1[19]) + assert.Equal(t, uint8(2), row1[20]) + assert.Equal(t, uint16(2), row1[21]) + assert.Equal(t, uint32(2), row1[22]) + assert.Equal(t, uint64(2), row1[23]) + assert.Equal(t, float32(2), row1[24]) + assert.Equal(t, float64(2), row1[25]) + assert.Equal(t, "tb2", row1[26]) + assert.Equal(t, "tb2", row1[27]) + row2 := result[1] + assert.Equal(t, now.Add(time.Second).UnixNano()/1e6, row2[0].(time.Time).UnixNano()/1e6) + for i := 1; i < 14; i++ { + assert.Nil(t, row2[i]) + } + assert.Equal(t, now.UnixNano()/1e6, row1[14].(time.Time).UnixNano()/1e6) + assert.Equal(t, true, row1[15]) + assert.Equal(t, int8(2), row1[16]) + assert.Equal(t, int16(2), row1[17]) + assert.Equal(t, int32(2), row1[18]) + assert.Equal(t, int64(2), row1[19]) + assert.Equal(t, uint8(2), row1[20]) + assert.Equal(t, uint16(2), row1[21]) + assert.Equal(t, uint32(2), row1[22]) + assert.Equal(t, uint64(2), row1[23]) + assert.Equal(t, float32(2), row1[24]) + assert.Equal(t, float64(2), row1[25]) + assert.Equal(t, "tb2", row1[26]) + assert.Equal(t, "tb2", row1[27]) + row3 := result[2] + assert.Equal(t, now.Add(time.Second*2).UnixNano()/1e6, row3[0].(time.Time).UnixNano()/1e6) + assert.Equal(t, true, row3[1]) + assert.Equal(t, int8(1), row3[2]) + assert.Equal(t, int16(1), row3[3]) + assert.Equal(t, int32(1), row3[4]) + assert.Equal(t, int64(1), row3[5]) + assert.Equal(t, uint8(1), row3[6]) + assert.Equal(t, uint16(1), row3[7]) + assert.Equal(t, uint32(1), row3[8]) + assert.Equal(t, uint64(1), row3[9]) + assert.Equal(t, float32(1), row3[10]) + assert.Equal(t, float64(1), row3[11]) + assert.Equal(t, "test_binary", row3[12]) + assert.Equal(t, "test_nchar", row3[13]) + assert.Equal(t, now.UnixNano()/1e6, row3[14].(time.Time).UnixNano()/1e6) + assert.Equal(t, true, row3[15]) + assert.Equal(t, int8(2), row3[16]) + assert.Equal(t, int16(2), row3[17]) + assert.Equal(t, int32(2), row3[18]) + assert.Equal(t, int64(2), row3[19]) + assert.Equal(t, uint8(2), row3[20]) + assert.Equal(t, uint16(2), row3[21]) + assert.Equal(t, uint32(2), row3[22]) + assert.Equal(t, uint64(2), row3[23]) + assert.Equal(t, float32(2), row3[24]) + assert.Equal(t, float64(2), row3[25]) + assert.Equal(t, "tb2", row3[26]) + assert.Equal(t, "tb2", row3[27]) + } +} + +func newTaosadapter(port string) *exec.Cmd { + command := "taosadapter" + if runtime.GOOS == "windows" { + command = "C:\\TDengine\\taosadapter.exe" + + } + return exec.Command(command, "--port", port) +} + +func startTaosadapter(cmd *exec.Cmd, port string) error { + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + err := cmd.Start() + if err != nil { + return err + } + for i := 0; i < 10; i++ { + time.Sleep(time.Millisecond * 100) + resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%s/-/ping", port)) + if err != nil { + continue + } + resp.Body.Close() + time.Sleep(time.Second) + return nil + } + return errors.New("taosadapter start failed") +} + +func stopTaosadapter(cmd *exec.Cmd) { + if cmd.Process == nil { + return + } + cmd.Process.Signal(syscall.SIGINT) + cmd.Process.Wait() + cmd.Process = nil +} + +func TestSTMTReconnect(t *testing.T) { + port := "36042" + cmd := newTaosadapter(port) + err := startTaosadapter(cmd, port) + if err != nil { + t.Fatal(err) + } + defer func() { + stopTaosadapter(cmd) + }() + config := NewConfig("ws://127.0.0.1:"+port, 0) + config.SetConnectUser("root") + config.SetConnectPass("taosdata") + config.SetMessageTimeout(3 * time.Second) + config.SetWriteWait(3 * time.Second) + config.SetEnableCompression(true) + config.SetErrorHandler(func(connector *Connector, err error) { + t.Log(err) + }) + config.SetCloseHandler(func() { + t.Log("stmt websocket closed") + }) + config.SetAutoReconnect(true) + config.SetReconnectRetryCount(3) + config.SetReconnectIntervalMs(2000) + connector, err := NewConnector(config) + if err != nil { + t.Error(err) + return + } + stmt, err := connector.Init() + assert.NoError(t, err) + stmt.Close() + stopTaosadapter(cmd) + startChan := make(chan struct{}) + go func() { + time.Sleep(time.Second * 3) + err = startTaosadapter(cmd, port) + startChan <- struct{}{} + if err != nil { + t.Error(err) + return + } + }() + stmt, err = connector.Init() + assert.Error(t, err) + <-startChan + time.Sleep(time.Second) + stmt, err = connector.Init() + assert.NoError(t, err) + stmt.Close() +} diff --git a/ws/tmq/config.go b/ws/tmq/config.go index 88ed25b..e119dcf 100644 --- a/ws/tmq/config.go +++ b/ws/tmq/config.go @@ -2,10 +2,7 @@ package tmq import ( "errors" - "fmt" "time" - - "github.com/taosdata/driver-go/v3/common/tmq" ) type config struct { @@ -22,6 +19,10 @@ type config struct { AutoCommitIntervalMS string SnapshotEnable string WithTableName string + EnableCompression bool + AutoReconnect bool + ReconnectIntervalMs int + ReconnectRetryCount int } func newConfig(url string, chanLength uint) *config { @@ -31,110 +32,70 @@ func newConfig(url string, chanLength uint) *config { } } -func (c *config) setConnectUser(user tmq.ConfigValue) error { - var ok bool - c.User, ok = user.(string) - if !ok { - return fmt.Errorf("td.connect.user requires string got %T", user) - } - return nil +func (c *config) setConnectUser(user string) { + c.User = user } -func (c *config) setConnectPass(pass tmq.ConfigValue) error { - var ok bool - c.Password, ok = pass.(string) - if !ok { - return fmt.Errorf("td.connect.pass requires string got %T", pass) - } - return nil +func (c *config) setConnectPass(pass string) { + c.Password = pass } -func (c *config) setGroupID(groupID tmq.ConfigValue) error { - var ok bool - c.GroupID, ok = groupID.(string) - if !ok { - return fmt.Errorf("group.id requires string got %T", groupID) - } - return nil +func (c *config) setGroupID(groupID string) { + c.GroupID = groupID } -func (c *config) setClientID(clientID tmq.ConfigValue) error { - var ok bool - c.ClientID, ok = clientID.(string) - if !ok { - return fmt.Errorf("client.id requires string got %T", clientID) - } - return nil +func (c *config) setClientID(clientID string) { + c.ClientID = clientID } -func (c *config) setAutoOffsetReset(offsetReset tmq.ConfigValue) error { - var ok bool - c.OffsetRest, ok = offsetReset.(string) - if !ok { - return fmt.Errorf("auto.offset.reset requires string got %T", offsetReset) - } - return nil +func (c *config) setAutoOffsetReset(offsetReset string) { + c.OffsetRest = offsetReset } -func (c *config) setMessageTimeout(timeout tmq.ConfigValue) error { - var ok bool - c.MessageTimeout, ok = timeout.(time.Duration) - if !ok { - return fmt.Errorf("ws.message.timeout requires time.Duration got %T", timeout) - } - if c.MessageTimeout < time.Second { +func (c *config) setMessageTimeout(timeout time.Duration) error { + if timeout < time.Second { return errors.New("ws.message.timeout cannot be less than 1 second") } + c.MessageTimeout = timeout return nil } -func (c *config) setWriteWait(writeWait tmq.ConfigValue) error { - var ok bool - c.WriteWait, ok = writeWait.(time.Duration) - if !ok { - return fmt.Errorf("ws.message.writeWait requires time.Duration got %T", writeWait) - } - if c.WriteWait < time.Second { +func (c *config) setWriteWait(writeWait time.Duration) error { + if writeWait < time.Second { return errors.New("ws.message.writeWait cannot be less than 1 second") } - if c.WriteWait < 0 { - return errors.New("ws.message.writeWait cannot be less than 0") - } + c.WriteWait = writeWait return nil } -func (c *config) setAutoCommit(enable tmq.ConfigValue) error { - var ok bool - c.AutoCommit, ok = enable.(string) - if !ok { - return fmt.Errorf("enable.auto.commit requires string got %T", enable) - } - return nil +func (c *config) setAutoCommit(enable string) { + c.AutoCommit = enable } -func (c *config) setAutoCommitIntervalMS(autoCommitIntervalMS tmq.ConfigValue) error { - var ok bool - c.AutoCommitIntervalMS, ok = autoCommitIntervalMS.(string) - if !ok { - return fmt.Errorf("auto.commit.interval.ms requires string got %T", autoCommitIntervalMS) - } - return nil +func (c *config) setAutoCommitIntervalMS(autoCommitIntervalMS string) { + c.AutoCommitIntervalMS = autoCommitIntervalMS } -func (c *config) setSnapshotEnable(enableSnapshot tmq.ConfigValue) error { - var ok bool - c.SnapshotEnable, ok = enableSnapshot.(string) - if !ok { - return fmt.Errorf("experimental.snapshot.enable requires string got %T", enableSnapshot) - } - return nil +func (c *config) setSnapshotEnable(enableSnapshot string) { + c.SnapshotEnable = enableSnapshot } -func (c *config) setWithTableName(withTableName tmq.ConfigValue) error { - var ok bool - c.WithTableName, ok = withTableName.(string) - if !ok { - return fmt.Errorf("msg.with.table.name requires string got %T", withTableName) - } - return nil +func (c *config) setWithTableName(withTableName string) { + c.WithTableName = withTableName +} + +func (c *config) setEnableCompression(enableCompression bool) { + c.EnableCompression = enableCompression +} + +func (c *config) setAutoReconnect(autoReconnect bool) { + c.AutoReconnect = autoReconnect +} + +func (c *config) setReconnectIntervalMs(reconnectIntervalMs int) { + c.ReconnectIntervalMs = reconnectIntervalMs +} + +func (c *config) setReconnectRetryCount(reconnectRetryCount int) { + c.ReconnectRetryCount = reconnectRetryCount } diff --git a/ws/tmq/consumer.go b/ws/tmq/consumer.go index df1889b..64bb1ed 100644 --- a/ws/tmq/consumer.go +++ b/ws/tmq/consumer.go @@ -6,6 +6,9 @@ import ( "encoding/binary" "errors" "fmt" + "net" + "net/url" + "strconv" "sync" "sync/atomic" "time" @@ -21,26 +24,33 @@ import ( ) type Consumer struct { - client *client.Client - requestID uint64 - err error - latestMessageID uint64 - listLock sync.RWMutex - sendChanList *list.List - messageTimeout time.Duration - url string - user string - password string - groupID string - clientID string - offsetRest string - autoCommit string - autoCommitIntervalMS string - snapshotEnable string - withTableName string - closeOnce sync.Once - closeChan chan struct{} - topics []string + client *client.Client + requestID uint64 + err error + dataParser *parser.TMQRawDataParser + listLock sync.RWMutex + sendChanList *list.List + messageTimeout time.Duration + autoCommit bool + autoCommitInterval time.Duration + nextAutoCommitTime time.Time + url string + user string + password string + groupID string + clientID string + offsetRest string + snapshotEnable string + withTableName string + closeOnce sync.Once + closeChan chan struct{} + topics []string + autoReconnect bool + reconnectIntervalMs int + reconnectRetryCount int + chanLength uint + writeWait time.Duration + dialer *websocket.Dialer } type IndexedChan struct { @@ -63,37 +73,101 @@ func NewConsumer(conf *tmq.ConfigMap) (*Consumer, error) { if err != nil { return nil, err } - ws, _, err := common.DefaultDialer.Dial(config.Url, nil) + autoCommit := true + if config.AutoCommit == "false" { + autoCommit = false + } + autoCommitInterval := time.Second * 5 + if config.AutoCommitIntervalMS != "" { + interval, err := strconv.ParseUint(config.AutoCommitIntervalMS, 10, 64) + if err != nil { + return nil, err + } + autoCommitInterval = time.Millisecond * time.Duration(interval) + } + + dialer := common.DefaultDialer + dialer.EnableCompression = config.EnableCompression + u, err := url.Parse(config.Url) if err != nil { return nil, err } + u.Path = "/rest/tmq" + ws, _, err := dialer.Dial(u.String(), nil) + if err != nil { + return nil, err + } + ws.EnableWriteCompression(config.EnableCompression) wsClient := client.NewClient(ws, config.ChanLength) - tmq := &Consumer{ - client: wsClient, - requestID: 0, - sendChanList: list.New(), - messageTimeout: config.MessageTimeout, - url: config.Url, - user: config.User, - password: config.Password, - groupID: config.GroupID, - clientID: config.ClientID, - offsetRest: config.OffsetRest, - autoCommit: config.AutoCommit, - autoCommitIntervalMS: config.AutoCommitIntervalMS, - snapshotEnable: config.SnapshotEnable, - withTableName: config.WithTableName, - closeChan: make(chan struct{}), - } - if config.WriteWait > 0 { - wsClient.WriteWait = config.WriteWait - } - wsClient.BinaryMessageHandler = tmq.handleBinaryMessage - wsClient.TextMessageHandler = tmq.handleTextMessage - wsClient.ErrorHandler = tmq.handleError - go wsClient.WritePump() - go wsClient.ReadPump() - return tmq, nil + + consumer := &Consumer{ + client: wsClient, + requestID: 0, + sendChanList: list.New(), + messageTimeout: config.MessageTimeout, + url: u.String(), + user: config.User, + password: config.Password, + groupID: config.GroupID, + clientID: config.ClientID, + offsetRest: config.OffsetRest, + autoCommit: autoCommit, + autoCommitInterval: autoCommitInterval, + snapshotEnable: config.SnapshotEnable, + withTableName: config.WithTableName, + closeChan: make(chan struct{}), + dataParser: parser.NewTMQRawDataParser(), + autoReconnect: config.AutoReconnect, + reconnectIntervalMs: config.ReconnectIntervalMs, + reconnectRetryCount: config.ReconnectRetryCount, + chanLength: config.ChanLength, + writeWait: config.WriteWait, + dialer: &dialer, + } + consumer.initClient(consumer.client) + return consumer, nil +} + +func (c *Consumer) initClient(client *client.Client) { + if c.writeWait > 0 { + client.WriteWait = c.writeWait + } + client.BinaryMessageHandler = c.handleBinaryMessage + client.TextMessageHandler = c.handleTextMessage + client.ErrorHandler = c.handleError + go client.WritePump() + go client.ReadPump() +} + +func (c *Consumer) reconnect() error { + reconnected := false + for i := 0; i < c.reconnectRetryCount; i++ { + time.Sleep(time.Duration(c.reconnectIntervalMs) * time.Millisecond) + conn, _, err := c.dialer.Dial(c.url, nil) + if err != nil { + continue + } + conn.EnableWriteCompression(c.dialer.EnableCompression) + cl := client.NewClient(conn, c.chanLength) + c.initClient(cl) + if c.client != nil { + c.client.Close() + } + c.client = cl + if len(c.topics) > 0 { + err = c.doSubscribe(c.topics, false) + if err != nil { + c.client.Close() + continue + } + } + reconnected = true + break + } + if !reconnected { + return errors.New("reconnect failed") + } + return nil } func configMapToConfig(m *tmq.ConfigMap) (*config, error) { @@ -153,51 +227,44 @@ func configMapToConfig(m *tmq.ConfigMap) (*config, error) { if err != nil { return nil, err } - config := newConfig(url.(string), chanLen.(uint)) - err = config.setMessageTimeout(messageTimeout.(time.Duration)) - if err != nil { - return nil, err - } - err = config.setWriteWait(writeWait.(time.Duration)) - if err != nil { - return nil, err - } - err = config.setConnectUser(user) - if err != nil { - return nil, err - } - err = config.setConnectPass(pass) - if err != nil { - return nil, err - } - err = config.setGroupID(groupID) - if err != nil { - return nil, err - } - err = config.setClientID(clientID) + enableCompression, err := m.Get("ws.message.enableCompression", false) if err != nil { return nil, err } - err = config.setAutoOffsetReset(offsetReset) + autoReconnect, err := m.Get("ws.autoReconnect", false) if err != nil { return nil, err } - err = config.setAutoCommit(enableAutoCommit) + reconnectIntervalMs, err := m.Get("ws.reconnectIntervalMs", int(2000)) if err != nil { return nil, err } - err = config.setAutoCommitIntervalMS(autoCommitIntervalMS) + reconnectRetryCount, err := m.Get("ws.reconnectRetryCount", int(3)) if err != nil { return nil, err } - err = config.setSnapshotEnable(enableSnapshot) + config := newConfig(url.(string), chanLen.(uint)) + err = config.setMessageTimeout(messageTimeout.(time.Duration)) if err != nil { return nil, err } - err = config.setWithTableName(withTableName) + err = config.setWriteWait(writeWait.(time.Duration)) if err != nil { return nil, err } + config.setConnectUser(user.(string)) + config.setConnectPass(pass.(string)) + config.setGroupID(groupID.(string)) + config.setClientID(clientID.(string)) + config.setAutoOffsetReset(offsetReset.(string)) + config.setAutoCommit(enableAutoCommit.(string)) + config.setAutoCommitIntervalMS(autoCommitIntervalMS.(string)) + config.setSnapshotEnable(enableSnapshot.(string)) + config.setWithTableName(withTableName.(string)) + config.setEnableCompression(enableCompression.(bool)) + config.setAutoReconnect(autoReconnect.(bool)) + config.setReconnectIntervalMs(reconnectIntervalMs.(int)) + config.setReconnectRetryCount(reconnectRetryCount.(int)) return config, nil } @@ -236,8 +303,9 @@ func (c *Consumer) handleBinaryMessage(message []byte) { } func (c *Consumer) handleError(err error) { - c.err = &WSError{err: err} - c.Close() + if !c.autoReconnect { + c.err = &WSError{err: err} + } } func (c *Consumer) generateReqID() uint64 { @@ -284,8 +352,7 @@ func (c *Consumer) findOutChanByID(index uint64) *list.Element { const ( TMQSubscribe = "subscribe" TMQPoll = "poll" - TMQFetch = "fetch" - TMQFetchBlock = "fetch_block" + TMQFetchRaw = "fetch_raw" TMQFetchJsonMeta = "fetch_json_meta" TMQCommit = "commit" TMQUnsubscribe = "unsubscribe" @@ -294,23 +361,31 @@ const ( TMQCommitOffset = "commit_offset" TMQCommitted = "committed" TMQPosition = "position" - TMQListTopics = "list_topics" ) var ClosedErr = errors.New("connection closed") func (c *Consumer) sendText(reqID uint64, envelope *client.Envelope) ([]byte, error) { - if !c.client.IsRunning() { - c.client.PutEnvelope(envelope) - return nil, ClosedErr - } channel := &IndexedChan{ index: reqID, channel: make(chan []byte, 1), } element := c.addMessageOutChan(channel) envelope.Type = websocket.TextMessage - c.client.Send(envelope) + err := c.client.Send(envelope) + if err != nil { + c.listLock.Lock() + c.sendChanList.Remove(element) + c.listLock.Unlock() + return nil, err + } + err = <-envelope.ErrorChan + if err != nil { + c.listLock.Lock() + c.sendChanList.Remove(element) + c.listLock.Unlock() + return nil, err + } ctx, cancel := context.WithTimeout(context.Background(), c.messageTimeout) defer cancel() select { @@ -333,22 +408,25 @@ func (c *Consumer) Subscribe(topic string, rebalanceCb RebalanceCb) error { } func (c *Consumer) SubscribeTopics(topics []string, rebalanceCb RebalanceCb) error { + return c.doSubscribe(topics, c.autoReconnect) +} + +func (c *Consumer) doSubscribe(topics []string, reconnect bool) error { if c.err != nil { return c.err } reqID := c.generateReqID() req := &SubscribeReq{ - ReqID: reqID, - User: c.user, - Password: c.password, - GroupID: c.groupID, - ClientID: c.clientID, - OffsetRest: c.offsetRest, - Topics: topics, - AutoCommit: c.autoCommit, - AutoCommitIntervalMS: c.autoCommitIntervalMS, - SnapshotEnable: c.snapshotEnable, - WithTableName: c.withTableName, + ReqID: reqID, + User: c.user, + Password: c.password, + GroupID: c.groupID, + ClientID: c.clientID, + OffsetRest: c.offsetRest, + Topics: topics, + AutoCommit: "false", + SnapshotEnable: c.snapshotEnable, + WithTableName: c.withTableName, } args, err := client.JsonI.Marshal(req) if err != nil { @@ -358,15 +436,30 @@ func (c *Consumer) SubscribeTopics(topics []string, rebalanceCb RebalanceCb) err Action: TMQSubscribe, Args: args, } - envelope := c.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - c.client.PutEnvelope(envelope) return err } respBytes, err := c.sendText(reqID, envelope) if err != nil { - return err + if !reconnect { + return err + } + var opError *net.OpError + if errors.Is(err, ClosedErr) || errors.Is(err, client.ClosedError) || errors.As(err, &opError) { + err = c.reconnect() + if err != nil { + return err + } + respBytes, err = c.sendText(reqID, envelope) + if err != nil { + return err + } + } else { + return err + } } var resp SubscribeResp err = client.JsonI.Unmarshal(respBytes, &resp) @@ -384,7 +477,17 @@ func (c *Consumer) SubscribeTopics(topics []string, rebalanceCb RebalanceCb) err // Poll messages func (c *Consumer) Poll(timeoutMs int) tmq.Event { if c.err != nil { - panic(c.err) + return tmq.NewTMQErrorWithErr(c.err) + } + if c.autoCommit { + if c.nextAutoCommitTime.IsZero() { + c.nextAutoCommitTime = time.Now().Add(c.autoCommitInterval) + } else { + if time.Now().After(c.nextAutoCommitTime) { + c.doCommit() + c.nextAutoCommitTime = time.Now().Add(c.autoCommitInterval) + } + } } reqID := c.generateReqID() req := &PollReq{ @@ -399,15 +502,30 @@ func (c *Consumer) Poll(timeoutMs int) tmq.Event { Action: TMQPoll, Args: args, } - envelope := c.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - c.client.PutEnvelope(envelope) return tmq.NewTMQErrorWithErr(err) } respBytes, err := c.sendText(reqID, envelope) if err != nil { - return tmq.NewTMQErrorWithErr(err) + if !c.autoReconnect { + return tmq.NewTMQErrorWithErr(err) + } + var opError *net.OpError + if errors.Is(err, ClosedErr) || errors.Is(err, client.ClosedError) || errors.As(err, &opError) { + err = c.reconnect() + if err != nil { + return tmq.NewTMQErrorWithErr(err) + } + respBytes, err = c.sendText(reqID, envelope) + if err != nil { + return tmq.NewTMQErrorWithErr(err) + } + } else { + return tmq.NewTMQErrorWithErr(err) + } } var resp PollResp err = client.JsonI.Unmarshal(respBytes, &resp) @@ -417,7 +535,6 @@ func (c *Consumer) Poll(timeoutMs int) tmq.Event { if resp.Code != 0 { panic(taosErrors.NewError(resp.Code, resp.Message)) } - c.latestMessageID = resp.MessageID if resp.HaveMessage { switch resp.MessageType { case common.TMQ_RES_DATA: @@ -500,10 +617,10 @@ func (c *Consumer) fetchJsonMeta(messageID uint64) (*tmq.Meta, error) { Action: TMQFetchJsonMeta, Args: args, } - envelope := c.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - c.client.PutEnvelope(envelope) return nil, err } respBytes, err := c.sendText(reqID, envelope) @@ -527,122 +644,91 @@ func (c *Consumer) fetchJsonMeta(messageID uint64) (*tmq.Meta, error) { } func (c *Consumer) fetch(messageID uint64) ([]*tmq.Data, error) { - var tmqData []*tmq.Data - for { - reqID := c.generateReqID() - req := &FetchReq{ - ReqID: reqID, - MessageID: messageID, - } - args, err := client.JsonI.Marshal(req) - if err != nil { - return nil, err - } - action := &client.WSAction{ - Action: TMQFetch, - Args: args, - } - envelope := c.client.GetEnvelope() - err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) - if err != nil { - c.client.PutEnvelope(envelope) - return nil, err - } - respBytes, err := c.sendText(reqID, envelope) - if err != nil { - return nil, err - } - var resp FetchResp - err = client.JsonI.Unmarshal(respBytes, &resp) - if err != nil { - return nil, err - } - if resp.Code != 0 { - return nil, taosErrors.NewError(resp.Code, resp.Message) - } - if resp.Completed { - break - } - // fetch block - { - req := &FetchBlockReq{ - ReqID: reqID, - MessageID: messageID, - } - args, err := client.JsonI.Marshal(req) - if err != nil { - return nil, err - } - action := &client.WSAction{ - Action: TMQFetchBlock, - Args: args, - } - envelope := c.client.GetEnvelope() - err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) - if err != nil { - c.client.PutEnvelope(envelope) - return nil, err - } - respBytes, err := c.sendText(reqID, envelope) - if err != nil { - return nil, err - } - block := respBytes[24:] - p := unsafe.Pointer(&block[0]) - data := parser.ReadBlock(p, resp.Rows, resp.FieldsTypes, resp.Precision) - tmqData = append(tmqData, &tmq.Data{ - TableName: resp.TableName, - Data: data, - }) + reqID := c.generateReqID() + req := &TMQFetchRawMetaReq{ + ReqID: reqID, + MessageID: messageID, + } + args, err := client.JsonI.Marshal(req) + if err != nil { + return nil, err + } + action := &client.WSAction{ + Action: TMQFetchRaw, + Args: args, + } + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) + err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) + if err != nil { + return nil, err + } + respBytes, err := c.sendText(reqID, envelope) + if err != nil { + return nil, err + } + blockInfo, err := c.dataParser.Parse(unsafe.Pointer(&respBytes[38])) + if err != nil { + return nil, err + } + tmqData := make([]*tmq.Data, len(blockInfo)) + for i := 0; i < len(blockInfo); i++ { + tmqData[i] = &tmq.Data{ + TableName: blockInfo[i].TableName, + Data: parser.ReadBlockSimple(blockInfo[i].RawBlock, blockInfo[i].Precision), } } return tmqData, nil } func (c *Consumer) Commit() ([]tmq.TopicPartition, error) { - return c.doCommit(c.latestMessageID) + err := c.doCommit() + if err != nil { + return nil, err + } + partitions, err := c.Assignment() + if err != nil { + return nil, err + } + return c.Committed(partitions, 0) } -func (c *Consumer) doCommit(messageID uint64) ([]tmq.TopicPartition, error) { +func (c *Consumer) doCommit() error { if c.err != nil { - return nil, c.err + return c.err } reqID := c.generateReqID() req := &CommitReq{ ReqID: reqID, - MessageID: messageID, + MessageID: 0, } args, err := client.JsonI.Marshal(req) if err != nil { - return nil, err + return err } action := &client.WSAction{ Action: TMQCommit, Args: args, } - envelope := c.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - c.client.PutEnvelope(envelope) - return nil, err + return err } respBytes, err := c.sendText(reqID, envelope) if err != nil { - return nil, err + return err } var resp CommitResp err = client.JsonI.Unmarshal(respBytes, &resp) if err != nil { - return nil, err + return err } if resp.Code != 0 { - return nil, taosErrors.NewError(resp.Code, resp.Message) - } - partitions, err := c.Assignment() - if err != nil { - return nil, err + return taosErrors.NewError(resp.Code, resp.Message) } - return c.Committed(partitions, 0) + return nil } func (c *Consumer) Unsubscribe() error { @@ -661,10 +747,10 @@ func (c *Consumer) Unsubscribe() error { Action: TMQUnsubscribe, Args: args, } - envelope := c.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - c.client.PutEnvelope(envelope) return err } respBytes, err := c.sendText(reqID, envelope) @@ -700,10 +786,10 @@ func (c *Consumer) Assignment() (partitions []tmq.TopicPartition, err error) { Action: TMQGetTopicAssignment, Args: args, } - envelope := c.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - c.client.PutEnvelope(envelope) return nil, err } respBytes, err := c.sendText(reqID, envelope) @@ -750,10 +836,10 @@ func (c *Consumer) Seek(partition tmq.TopicPartition, ignoredTimeoutMs int) erro Action: TMQSeek, Args: args, } - envelope := c.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - c.client.PutEnvelope(envelope) return err } respBytes, err := c.sendText(reqID, envelope) @@ -792,10 +878,10 @@ func (c *Consumer) Committed(partitions []tmq.TopicPartition, timeoutMs int) (of Action: TMQCommitted, Args: args, } - envelope := c.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - c.client.PutEnvelope(envelope) return nil, err } respBytes, err := c.sendText(reqID, envelope) @@ -824,6 +910,8 @@ func (c *Consumer) CommitOffsets(offsets []tmq.TopicPartition) ([]tmq.TopicParti if c.err != nil { return nil, c.err } + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) for i := 0; i < len(offsets); i++ { reqID := c.generateReqID() req := &CommitOffsetReq{ @@ -840,10 +928,9 @@ func (c *Consumer) CommitOffsets(offsets []tmq.TopicPartition) ([]tmq.TopicParti Action: TMQCommitOffset, Args: args, } - envelope := c.client.GetEnvelope() + envelope.Reset() err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - c.client.PutEnvelope(envelope) return nil, err } respBytes, err := c.sendText(reqID, envelope) @@ -883,10 +970,10 @@ func (c *Consumer) Position(partitions []tmq.TopicPartition) (offsets []tmq.Topi Action: TMQPosition, Args: args, } - envelope := c.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - c.client.PutEnvelope(envelope) return nil, err } respBytes, err := c.sendText(reqID, envelope) diff --git a/ws/tmq/consumer_test.go b/ws/tmq/consumer_test.go index a7bf95e..37dd34b 100644 --- a/ws/tmq/consumer_test.go +++ b/ws/tmq/consumer_test.go @@ -1,10 +1,15 @@ package tmq import ( + "errors" "fmt" "io/ioutil" "net/http" + "os" + "os/exec" + "runtime" "strings" + "syscall" "testing" "time" @@ -124,7 +129,7 @@ func TestConsumer(t *testing.T) { } }() consumer, err := NewConsumer(&tmq.ConfigMap{ - "ws.url": "ws://127.0.0.1:6041/rest/tmq", + "ws.url": "ws://127.0.0.1:6041", "ws.message.channelLen": uint(0), "ws.message.timeout": common.DefaultMessageTimeout, "ws.message.writeWait": common.DefaultWriteWait, @@ -266,7 +271,7 @@ func TestSeek(t *testing.T) { } defer cleanSeekEnv() consumer, err := NewConsumer(&tmq.ConfigMap{ - "ws.url": "ws://127.0.0.1:6041/rest/tmq", + "ws.url": "ws://127.0.0.1:6041", "ws.message.channelLen": uint(0), "ws.message.timeout": common.DefaultMessageTimeout, "ws.message.writeWait": common.DefaultWriteWait, @@ -276,8 +281,8 @@ func TestSeek(t *testing.T) { "client.id": "test_consumer", "auto.offset.reset": "earliest", "enable.auto.commit": "false", - "experimental.snapshot.enable": "false", "msg.with.table.name": "true", + "ws.message.enableCompression": true, }) if err != nil { t.Error(err) @@ -350,3 +355,623 @@ func TestSeek(t *testing.T) { assert.Equal(t, "test_ws_tmq_seek_topic", *partitions[0].Topic) assert.GreaterOrEqual(t, partitions[0].Offset, messageOffset) } + +func prepareAutocommitEnv() error { + var err error + steps := []string{ + "drop topic if exists test_ws_tmq_autocommit_topic", + "drop database if exists test_ws_tmq_autocommit", + "create database test_ws_tmq_autocommit vgroups 1 WAL_RETENTION_PERIOD 86400", + "create topic test_ws_tmq_autocommit_topic as database test_ws_tmq_autocommit", + "create table test_ws_tmq_autocommit.t1(ts timestamp,v int)", + "insert into test_ws_tmq_autocommit.t1 values (now,1)", + } + for _, step := range steps { + err = doRequest(step) + if err != nil { + return err + } + } + return nil +} + +func cleanAutocommitEnv() error { + var err error + time.Sleep(2 * time.Second) + steps := []string{ + "drop topic if exists test_ws_tmq_autocommit_topic", + "drop database if exists test_ws_tmq_autocommit", + } + for _, step := range steps { + err = doRequest(step) + if err != nil { + return err + } + } + return nil +} + +func TestAutoCommit(t *testing.T) { + err := prepareAutocommitEnv() + if err != nil { + t.Error(err) + return + } + defer cleanAutocommitEnv() + consumer, err := NewConsumer(&tmq.ConfigMap{ + "ws.url": "ws://127.0.0.1:6041", + "ws.message.channelLen": uint(0), + "ws.message.timeout": common.DefaultMessageTimeout, + "ws.message.writeWait": common.DefaultWriteWait, + "td.connect.user": "root", + "td.connect.pass": "taosdata", + "group.id": "test", + "client.id": "test_consumer", + "auto.offset.reset": "earliest", + "enable.auto.commit": "true", + "auto.commit.interval.ms": "1000", + "msg.with.table.name": "true", + }) + assert.NoError(t, err) + if err != nil { + t.Error(err) + return + } + defer func() { + consumer.Unsubscribe() + consumer.Close() + }() + topic := []string{"test_ws_tmq_autocommit_topic"} + err = consumer.SubscribeTopics(topic, nil) + if err != nil { + t.Error(err) + return + } + partitions, err := consumer.Assignment() + assert.NoError(t, err) + assert.Equal(t, 1, len(partitions)) + assert.Equal(t, "test_ws_tmq_autocommit_topic", *partitions[0].Topic) + assert.Equal(t, tmq.Offset(0), partitions[0].Offset) + + offset, err := consumer.Committed(partitions, 0) + assert.NoError(t, err) + assert.Equal(t, 1, len(offset)) + assert.Equal(t, tmq.OffsetInvalid, offset[0].Offset) + + //poll + messageOffset := tmq.Offset(0) + haveMessage := false + for i := 0; i < 5; i++ { + event := consumer.Poll(500) + if event != nil { + haveMessage = true + messageOffset = event.(*tmq.DataMessage).Offset() + } + } + assert.True(t, haveMessage) + + offset, err = consumer.Committed(partitions, 0) + assert.NoError(t, err) + assert.Equal(t, 1, len(offset)) + assert.GreaterOrEqual(t, offset[0].Offset, messageOffset) +} + +func prepareMultiBlockEnv() error { + var err error + steps := []string{ + "drop topic if exists test_ws_tmq_multi_block_topic", + "drop database if exists test_ws_tmq_multi_block", + "create database test_ws_tmq_multi_block vgroups 1 WAL_RETENTION_PERIOD 86400", + "create topic test_ws_tmq_multi_block_topic as database test_ws_tmq_multi_block", + "create table test_ws_tmq_multi_block.t1(ts timestamp,v int)", + "create table test_ws_tmq_multi_block.t2(ts timestamp,v int)", + "create table test_ws_tmq_multi_block.t3(ts timestamp,v int)", + "create table test_ws_tmq_multi_block.t4(ts timestamp,v int)", + "create table test_ws_tmq_multi_block.t5(ts timestamp,v int)", + "create table test_ws_tmq_multi_block.t6(ts timestamp,v int)", + "create table test_ws_tmq_multi_block.t7(ts timestamp,v int)", + "create table test_ws_tmq_multi_block.t8(ts timestamp,v int)", + "create table test_ws_tmq_multi_block.t9(ts timestamp,v int)", + "create table test_ws_tmq_multi_block.t10(ts timestamp,v int)", + "insert into test_ws_tmq_multi_block.t1 values (now,1) test_ws_tmq_multi_block.t2 values (now,2) " + + "test_ws_tmq_multi_block.t3 values (now,3) test_ws_tmq_multi_block.t4 values (now,4)" + + "test_ws_tmq_multi_block.t5 values (now,5) test_ws_tmq_multi_block.t6 values (now,6)" + + "test_ws_tmq_multi_block.t7 values (now,7) test_ws_tmq_multi_block.t8 values (now,8)" + + "test_ws_tmq_multi_block.t9 values (now,9) test_ws_tmq_multi_block.t10 values (now,10)", + } + for _, step := range steps { + err = doRequest(step) + if err != nil { + return err + } + } + return nil +} + +func cleanMultiBlockEnv() error { + var err error + time.Sleep(2 * time.Second) + steps := []string{ + "drop topic if exists test_ws_tmq_multi_block_topic", + "drop database if exists test_ws_tmq_multi_block", + } + for _, step := range steps { + err = doRequest(step) + if err != nil { + return err + } + } + return nil +} + +func TestMultiBlock(t *testing.T) { + err := prepareMultiBlockEnv() + assert.NoError(t, err) + defer cleanMultiBlockEnv() + consumer, err := NewConsumer(&tmq.ConfigMap{ + "ws.url": "ws://127.0.0.1:6041", + "ws.message.channelLen": uint(0), + "ws.message.timeout": common.DefaultMessageTimeout, + "ws.message.writeWait": common.DefaultWriteWait, + "td.connect.user": "root", + "td.connect.pass": "taosdata", + "group.id": "test", + "client.id": "test_consumer", + "auto.offset.reset": "earliest", + "enable.auto.commit": "true", + "auto.commit.interval.ms": "1000", + "msg.with.table.name": "true", + }) + assert.NoError(t, err) + if err != nil { + t.Error(err) + return + } + defer func() { + consumer.Unsubscribe() + consumer.Close() + }() + topic := []string{"test_ws_tmq_multi_block_topic"} + err = consumer.SubscribeTopics(topic, nil) + if err != nil { + t.Error(err) + return + } + for i := 0; i < 10; i++ { + event := consumer.Poll(500) + if event == nil { + continue + } + switch e := event.(type) { + case *tmq.DataMessage: + data := e.Value().([]*tmq.Data) + assert.Equal(t, "test_ws_tmq_multi_block", e.DBName()) + assert.Equal(t, 10, len(data)) + return + } + } +} + +func Test_configMapToConfigWrong(t *testing.T) { + type args struct { + m *tmq.ConfigMap + } + tests := []struct { + name string + args args + wantErr string + }{ + { + name: "url", + args: args{ + m: &tmq.ConfigMap{ + "ws.url": 123, + }, + }, + wantErr: "ws.url expects type string, not int", + }, + { + name: "empty url", + args: args{ + m: &tmq.ConfigMap{ + "ws.url": "", + }, + }, + wantErr: "ws.url required", + }, + { + name: "channelLen", + args: args{ + m: &tmq.ConfigMap{ + "ws.url": "ws://127.0.0.1:6041", + "ws.message.channelLen": "not a uint", + }, + }, + wantErr: "ws.message.channelLen expects type uint, not string", + }, + { + name: "ws.message.timeout", + args: args{ + m: &tmq.ConfigMap{ + "ws.url": "ws://127.0.0.1:6041", + "ws.message.timeout": "xx", + }, + }, + wantErr: "ws.message.timeout expects type time.Duration, not string", + }, + { + name: "ws.message.writeWait", + args: args{ + m: &tmq.ConfigMap{ + "ws.url": "ws://127.0.0.1:6041", + "ws.message.writeWait": "xx", + }, + }, + wantErr: "ws.message.writeWait expects type time.Duration, not string", + }, + { + name: "td.connect.user", + args: args{ + m: &tmq.ConfigMap{ + "ws.url": "ws://127.0.0.1:6041", + "td.connect.user": 123, + }, + }, + wantErr: "td.connect.user expects type string, not int", + }, + { + name: "td.connect.pass", + args: args{ + m: &tmq.ConfigMap{ + "ws.url": "ws://127.0.0.1:6041", + "td.connect.pass": 123, + }, + }, + wantErr: "td.connect.pass expects type string, not int", + }, + { + name: "group.id", + args: args{ + m: &tmq.ConfigMap{ + "ws.url": "ws://127.0.0.1:6041", + "group.id": 123, + }, + }, + wantErr: "group.id expects type string, not int", + }, + { + name: "client.id", + args: args{ + m: &tmq.ConfigMap{ + "ws.url": "ws://127.0.0.1:6041", + "client.id": 123, + }, + }, + wantErr: "client.id expects type string, not int", + }, + { + name: "auto.offset.reset", + args: args{ + m: &tmq.ConfigMap{ + "ws.url": "ws://127.0.0.1:6041", + "auto.offset.reset": 123, + }, + }, + wantErr: "auto.offset.reset expects type string, not int", + }, + { + name: "enable.auto.commit", + args: args{ + m: &tmq.ConfigMap{ + "ws.url": "ws://127.0.0.1:6041", + "enable.auto.commit": 123, + }, + }, + wantErr: "enable.auto.commit expects type string, not int", + }, + { + name: "auto.commit.interval.ms", + args: args{ + m: &tmq.ConfigMap{ + "ws.url": "ws://127.0.0.1:6041", + "auto.commit.interval.ms": 123, + }, + }, + wantErr: "auto.commit.interval.ms expects type string, not int", + }, + { + name: "experimental.snapshot.enable", + args: args{ + m: &tmq.ConfigMap{ + "ws.url": "ws://127.0.0.1:6041", + "experimental.snapshot.enable": 123, + }, + }, + wantErr: "experimental.snapshot.enable expects type string, not int", + }, + { + name: "msg.with.table.name", + args: args{ + m: &tmq.ConfigMap{ + "ws.url": "ws://127.0.0.1:6041", + "msg.with.table.name": 123, + }, + }, + wantErr: "msg.with.table.name expects type string, not int", + }, + { + name: "ws.message.enableCompression", + args: args{ + m: &tmq.ConfigMap{ + "ws.url": "ws://127.0.0.1:6041", + "ws.message.enableCompression": 123, + }, + }, + wantErr: "ws.message.enableCompression expects type bool, not int", + }, + { + name: "ws.message.timeout < 1s", + args: args{ + m: &tmq.ConfigMap{ + "ws.url": "ws://127.0.0.1:6041", + "ws.message.timeout": time.Millisecond, + }, + }, + wantErr: "ws.message.timeout cannot be less than 1 second", + }, + { + name: "ws.message.writeWait < 1s", + args: args{ + m: &tmq.ConfigMap{ + "ws.url": "ws://127.0.0.1:6041", + "ws.message.writeWait": time.Millisecond, + }, + }, + wantErr: "ws.message.writeWait cannot be less than 1 second", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := configMapToConfig(tt.args.m) + assert.Nil(t, got) + assert.Equal(t, tt.wantErr, err.Error()) + }) + } +} + +func prepareMetaEnv() error { + var err error + steps := []string{ + "drop topic if exists test_ws_tmq_meta_topic", + "drop database if exists test_ws_tmq_meta", + "create database test_ws_tmq_meta vgroups 1 WAL_RETENTION_PERIOD 86400", + "create topic test_ws_tmq_meta_topic with meta as database test_ws_tmq_meta", + } + for _, step := range steps { + err = doRequest(step) + if err != nil { + return err + } + } + return nil +} + +func cleanMetaEnv() error { + var err error + time.Sleep(2 * time.Second) + steps := []string{ + "drop topic if exists test_ws_tmq_meta_topic", + "drop database if exists test_ws_tmq_meta", + } + for _, step := range steps { + err = doRequest(step) + if err != nil { + return err + } + } + return nil +} + +func TestMeta(t *testing.T) { + err := prepareMetaEnv() + assert.NoError(t, err) + defer cleanMetaEnv() + consumer, err := NewConsumer(&tmq.ConfigMap{ + "ws.url": "ws://127.0.0.1:6041", + "ws.message.channelLen": uint(0), + "ws.message.timeout": common.DefaultMessageTimeout, + "ws.message.writeWait": common.DefaultWriteWait, + "td.connect.user": "root", + "td.connect.pass": "taosdata", + "group.id": "test", + "client.id": "test_consumer", + "auto.offset.reset": "earliest", + "enable.auto.commit": "true", + "auto.commit.interval.ms": "1000", + "msg.with.table.name": "true", + }) + err = consumer.Subscribe("test_ws_tmq_meta_topic", nil) + assert.NoError(t, err) + defer func() { + consumer.Unsubscribe() + consumer.Close() + }() + go func() { + doRequest("create table test_ws_tmq_meta.st(ts timestamp,v int) tags (cn binary(20))") + doRequest("create table test_ws_tmq_meta.t1 using test_ws_tmq_meta.st tags ('t1')") + doRequest("insert into test_ws_tmq_meta.t1 values (now,1)") + doRequest("insert into test_ws_tmq_meta.t2 using test_ws_tmq_meta.st tags ('t1') values (now,2)") + time.Sleep(time.Second) + doRequest("insert into test_ws_tmq_meta.t1 values (now,1)") + doRequest("insert into test_ws_tmq_meta.t1 values (now,1)") + }() + for i := 0; i < 10; i++ { + event := consumer.Poll(500) + if event == nil { + continue + } + switch e := event.(type) { + case *tmq.DataMessage: + t.Log(e) + assert.Equal(t, "test_ws_tmq_meta", e.DBName()) + case *tmq.MetaDataMessage: + assert.Equal(t, "test_ws_tmq_meta", e.DBName()) + assert.Equal(t, "test_ws_tmq_meta_topic", e.Topic()) + t.Log(e) + case *tmq.MetaMessage: + assert.Equal(t, "test_ws_tmq_meta", e.DBName()) + t.Log(e) + } + } +} + +func newTaosadapter(port string) *exec.Cmd { + command := "taosadapter" + if runtime.GOOS == "windows" { + command = "C:\\TDengine\\taosadapter.exe" + + } + return exec.Command(command, "--port", port) +} + +func startTaosadapter(cmd *exec.Cmd, port string) error { + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + err := cmd.Start() + if err != nil { + return err + } + for i := 0; i < 10; i++ { + time.Sleep(time.Millisecond * 100) + resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%s/-/ping", port)) + if err != nil { + continue + } + resp.Body.Close() + time.Sleep(time.Second) + return nil + } + return errors.New("taosadapter start failed") +} + +func stopTaosadapter(cmd *exec.Cmd) { + if cmd.Process == nil { + return + } + cmd.Process.Signal(syscall.SIGINT) + cmd.Process.Wait() + cmd.Process = nil +} + +func prepareSubReconnectEnv() error { + var err error + steps := []string{ + "drop topic if exists test_ws_tmq_sub_reconnect_topic", + "drop database if exists test_ws_tmq_sub_reconnect", + "create database test_ws_tmq_sub_reconnect vgroups 1 WAL_RETENTION_PERIOD 86400", + "create topic test_ws_tmq_sub_reconnect_topic as database test_ws_tmq_sub_reconnect", + } + for _, step := range steps { + err = doRequest(step) + if err != nil { + return err + } + } + return nil +} + +func cleanSubReconnectEnv() error { + var err error + time.Sleep(2 * time.Second) + steps := []string{ + "drop topic if exists test_ws_tmq_sub_reconnect_topic", + "drop database if exists test_ws_tmq_sub_reconnect", + } + for _, step := range steps { + err = doRequest(step) + if err != nil { + return err + } + } + return nil +} + +func TestSubscribeReconnect(t *testing.T) { + port := "36043" + cmd := newTaosadapter(port) + err := startTaosadapter(cmd, port) + assert.NoError(t, err) + defer func() { + stopTaosadapter(cmd) + }() + prepareSubReconnectEnv() + defer cleanSubReconnectEnv() + consumer, err := NewConsumer(&tmq.ConfigMap{ + "ws.url": "ws://127.0.0.1:" + port, + "ws.message.channelLen": uint(0), + "ws.message.timeout": time.Second * 5, + "ws.message.writeWait": common.DefaultWriteWait, + "td.connect.user": "root", + "td.connect.pass": "taosdata", + "group.id": "test", + "client.id": "test_consumer", + "auto.offset.reset": "earliest", + "enable.auto.commit": "true", + "auto.commit.interval.ms": "1000", + "msg.with.table.name": "true", + "ws.autoReconnect": true, + "ws.reconnectIntervalMs": 3000, + "ws.reconnectRetryCount": 3, + }) + assert.NoError(t, err) + stopTaosadapter(cmd) + time.Sleep(time.Second) + startChan := make(chan struct{}) + go func() { + time.Sleep(time.Second * 3) + err = startTaosadapter(cmd, port) + if err != nil { + t.Error(err) + return + } + startChan <- struct{}{} + }() + err = consumer.Subscribe("test_ws_tmq_sub_reconnect_topic", nil) + assert.Error(t, err) + <-startChan + time.Sleep(time.Second) + err = consumer.Subscribe("test_ws_tmq_sub_reconnect_topic", nil) + assert.NoError(t, err) + doRequest("create table test_ws_tmq_sub_reconnect.st(ts timestamp,v int) tags (cn binary(20))") + doRequest("create table test_ws_tmq_sub_reconnect.t1 using test_ws_tmq_sub_reconnect.st tags ('t1')") + doRequest("insert into test_ws_tmq_sub_reconnect.t1 values (now,1)") + stopTaosadapter(cmd) + go func() { + time.Sleep(time.Second * 3) + startTaosadapter(cmd, port) + startChan <- struct{}{} + }() + time.Sleep(time.Second) + event := consumer.Poll(500) + assert.NotNil(t, event) + _, ok := event.(tmq.Error) + assert.True(t, ok) + <-startChan + haveMessage := false + for i := 0; i < 10; i++ { + event := consumer.Poll(500) + if event == nil { + continue + } + switch e := event.(type) { + case *tmq.DataMessage: + t.Log(e) + assert.Equal(t, "test_ws_tmq_sub_reconnect", e.DBName()) + haveMessage = true + break + default: + t.Log(e) + } + } + assert.True(t, haveMessage) +} diff --git a/ws/tmq/proto.go b/ws/tmq/proto.go index d9b8c1d..3a17c8b 100644 --- a/ws/tmq/proto.go +++ b/ws/tmq/proto.go @@ -196,3 +196,8 @@ type PositionResp struct { Timing int64 `json:"timing"` Position []int64 `json:"position"` } + +type TMQFetchRawMetaReq struct { + ReqID uint64 `json:"req_id"` + MessageID uint64 `json:"message_id"` +}