diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 6e4b5fe..0ea2913 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -1,34 +1,154 @@ name: Go -on: [ push,pull_request ] +on: + pull_request: + branches: + - 'main' + - '3.0' + workflow_dispatch: + inputs: + tbBranch: + description: 'TDengine branch' + required: true + type: string + +env: + SCCACHE_GHA_ENABLED: "true" jobs: build: - runs-on: ubuntu-latest - strategy: - matrix: - go: [ '1.14', '1.19' ] - name: Go ${{ matrix.go }} + runs-on: ubuntu-22.04 + name: Build + outputs: + commit_id: ${{ steps.get_commit_id.outputs.commit_id }} steps: - - name: checkout + - name: checkout TDengine by pr + if: github.event_name == 'pull_request' uses: actions/checkout@v3 with: - path: 'driver-go' - - name: checkout TDengine + repository: 'taosdata/TDengine' + path: 'TDengine' + ref: ${{ github.base_ref }} + + - name: checkout TDengine manually + if: github.event_name == 'workflow_dispatch' uses: actions/checkout@v3 with: repository: 'taosdata/TDengine' path: 'TDengine' - ref: 'main' + ref: ${{ inputs.tbBranch }} + + - name: get_commit_id + id: get_commit_id + run: | + cd TDengine + echo "commit_id=$(git rev-parse HEAD)" >> $GITHUB_OUTPUT + + - name: Run sccache-cache + uses: mozilla-actions/sccache-action@v0.0.3 + + - name: Cache server by pr + if: github.event_name == 'pull_request' + id: cache-server-pr + uses: actions/cache@v3 + with: + path: server.tar.gz + key: ${{ runner.os }}-build-${{ github.base_ref }}-${{ steps.get_commit_id.outputs.commit_id }} + + - name: Cache server manually + if: github.event_name == 'workflow_dispatch' + id: cache-server-manually + uses: actions/cache@v3 + with: + path: server.tar.gz + key: ${{ runner.os }}-build-${{ inputs.tbBranch }}-${{ steps.get_commit_id.outputs.commit_id }} + + - name: prepare install + if: > + (github.event_name == 'workflow_dispatch' && steps.cache-server-manually.outputs.cache-hit != 'true') || + (github.event_name == 'pull_request' && steps.cache-server-pr.outputs.cache-hit != 'true') + run: sudo apt install -y libgeos-dev - name: install TDengine + if: > + (github.event_name == 'workflow_dispatch' && steps.cache-server-manually.outputs.cache-hit != 'true') || + (github.event_name == 'pull_request' && steps.cache-server-pr.outputs.cache-hit != 'true') run: | cd TDengine mkdir debug cd debug - cmake .. -DBUILD_JDBC=false -DBUILD_TOOLS=false -DBUILD_HTTP=false -DBUILD_TEST=off - make -j32 - sudo make install + cmake .. -DBUILD_TEST=off -DBUILD_HTTP=false -DVERNUMBER=3.9.9.9 -DCMAKE_C_COMPILER_LAUNCHER=sccache -DCMAKE_CXX_COMPILER_LAUNCHER=sccache + make -j 4 + + - name: package + if: > + (github.event_name == 'workflow_dispatch' && steps.cache-server-manually.outputs.cache-hit != 'true') || + (github.event_name == 'pull_request' && steps.cache-server-pr.outputs.cache-hit != 'true') + run: | + mkdir -p ./release + cp ./TDengine/debug/build/bin/taos ./release/ + cp ./TDengine/debug/build/bin/taosd ./release/ + cp ./TDengine/tools/taosadapter/taosadapter ./release/ + cp ./TDengine/debug/build/lib/libtaos.so.3.9.9.9 ./release/ + cp ./TDengine/debug/build/lib/librocksdb.so.8.1.1 ./release/ ||: + cp ./TDengine/include/client/taos.h ./release/ + cat >./release/install.sh<> $GITHUB_OUTPUT + + - name: Run sccache-cache + uses: mozilla-actions/sccache-action@v0.0.3 + + - name: Cache server + id: cache-server + uses: actions/cache@v3 + with: + path: server.tar.gz + key: ${{ runner.os }}-build-${{ github.ref_name }}-${{ steps.get_commit_id.outputs.commit_id }} + + - name: prepare install + if: steps.cache-server.outputs.cache-hit != 'true' + run: sudo apt install -y libgeos-dev + + - name: Run sccache-cache + uses: mozilla-actions/sccache-action@v0.0.3 + + - name: install TDengine + if: steps.cache-server.outputs.cache-hit != 'true' + run: | + cd TDengine + mkdir debug + cd debug + cmake .. -DBUILD_JDBC=false -DBUILD_TEST=off -DBUILD_HTTP=false -DVERNUMBER=3.9.9.9 -DCMAKE_C_COMPILER_LAUNCHER=sccache -DCMAKE_CXX_COMPILER_LAUNCHER=sccache + make -j 4 + + - name: package + if: steps.cache-server.outputs.cache-hit != 'true' + run: | + mkdir -p ./release + cp ./TDengine/debug/build/bin/taos ./release/ + cp ./TDengine/debug/build/bin/taosd ./release/ + cp ./TDengine/tools/taosadapter/taosadapter ./release/ + cp ./TDengine/debug/build/lib/libtaos.so.3.9.9.9 ./release/ + cp ./TDengine/debug/build/lib/librocksdb.so.8.1.1 ./release/ ||: + cp ./TDengine/include/client/taos.h ./release/ + cat >./release/install.sh<start.sh<= record { + break + } + event := consumer.Poll(500) + if event != nil { + t.Log(event) + data := event.(*tmq.DataMessage).Value().([]*tmq.Data) + for _, datum := range data { + dataCount += len(datum.Data) + } + time.Sleep(time.Second * 2) + _, err = consumer.Commit() + assert.NoError(t, err) + } + } + assert.Equal(t, record, dataCount) + + //assignment after poll + assignment, err = consumer.Assignment() + t.Log(assignment) + assert.NoError(t, err) + assert.Equal(t, vgroups, len(assignment)) + for i := 0; i < len(assignment); i++ { + assert.Equal(t, topic, *assignment[i].Topic) + } + + // seek + for i := 0; i < len(assignment); i++ { + err = consumer.Seek(tmq.TopicPartition{ + Topic: &topic, + Partition: assignment[i].Partition, + Offset: 0, + }, 0) + assert.NoError(t, err) + } + + //assignment after seek + assignment, err = consumer.Assignment() + t.Log(assignment) + assert.NoError(t, err) + assert.Equal(t, vgroups, len(assignment)) + for i := 0; i < len(assignment); i++ { + assert.Equal(t, tmq.Offset(0), assignment[i].Offset) + assert.Equal(t, topic, *assignment[i].Topic) + } + + //poll after seek + dataCount = 0 + for i := 0; i < 20; i++ { + if dataCount >= record { + break + } + event := consumer.Poll(500) + if event != nil { + t.Log(event) + data := event.(*tmq.DataMessage).Value().([]*tmq.Data) + for _, datum := range data { + dataCount += len(datum.Data) + } + } + _, err = consumer.Commit() + assert.NoError(t, err) + } + assert.Equal(t, record, dataCount) + + //assignment after poll + assignment, err = consumer.Assignment() + t.Log(assignment) + assert.NoError(t, err) + assert.Equal(t, vgroups, len(assignment)) + for i := 0; i < len(assignment); i++ { + assert.Equal(t, topic, *assignment[i].Topic) + } + consumer.Close() +} + +func execWithoutResult(conn unsafe.Pointer, sql string) error { + result := wrapper.TaosQuery(conn, sql) + defer wrapper.TaosFreeResult(result) + code := wrapper.TaosError(result) + if code != 0 { + errStr := wrapper.TaosErrorStr(result) + wrapper.TaosFreeResult(result) + return &errors.TaosError{Code: int32(code), ErrStr: errStr} + } + return nil } diff --git a/bench/README.md b/bench/README.md index 9994143..2fea37d 100644 --- a/bench/README.md +++ b/bench/README.md @@ -7,4 +7,4 @@ go tool pprof memprofile.out ``` ```shell go tool pprof profile.out -``` \ No newline at end of file +``` diff --git a/common/column.go b/common/column.go index de5b840..aad229c 100644 --- a/common/column.go +++ b/common/column.go @@ -20,6 +20,7 @@ var ( NullTime = reflect.TypeOf(types.NullTime{}) NullBool = reflect.TypeOf(types.NullBool{}) NullString = reflect.TypeOf(types.NullString{}) + Bytes = reflect.TypeOf([]byte{}) NullJson = reflect.TypeOf(types.NullJson{}) UnknownType = reflect.TypeOf(new(interface{})).Elem() ) @@ -40,4 +41,6 @@ var ColumnTypeMap = map[int]reflect.Type{ TSDB_DATA_TYPE_NCHAR: NullString, TSDB_DATA_TYPE_TIMESTAMP: NullTime, TSDB_DATA_TYPE_JSON: NullJson, + TSDB_DATA_TYPE_VARBINARY: Bytes, + TSDB_DATA_TYPE_GEOMETRY: Bytes, } diff --git a/common/const.go b/common/const.go index 42ea950..ea2b7c9 100644 --- a/common/const.go +++ b/common/const.go @@ -23,87 +23,6 @@ const ( TSDB_OPTION_USE_ADAPTER ) -const ( - TSDB_DATA_TYPE_NULL = 0 // 1 bytes - TSDB_DATA_TYPE_BOOL = 1 // 1 bytes - TSDB_DATA_TYPE_TINYINT = 2 // 1 byte - TSDB_DATA_TYPE_SMALLINT = 3 // 2 bytes - TSDB_DATA_TYPE_INT = 4 // 4 bytes - TSDB_DATA_TYPE_BIGINT = 5 // 8 bytes - TSDB_DATA_TYPE_FLOAT = 6 // 4 bytes - TSDB_DATA_TYPE_DOUBLE = 7 // 8 bytes - TSDB_DATA_TYPE_BINARY = 8 // string - TSDB_DATA_TYPE_TIMESTAMP = 9 // 8 bytes - TSDB_DATA_TYPE_NCHAR = 10 // unicode string - TSDB_DATA_TYPE_UTINYINT = 11 // 1 byte - TSDB_DATA_TYPE_USMALLINT = 12 // 2 bytes - TSDB_DATA_TYPE_UINT = 13 // 4 bytes - TSDB_DATA_TYPE_UBIGINT = 14 // 8 bytes - TSDB_DATA_TYPE_JSON = 15 - TSDB_DATA_TYPE_VARBINARY = 16 - TSDB_DATA_TYPE_DECIMAL = 17 - TSDB_DATA_TYPE_BLOB = 18 - TSDB_DATA_TYPE_MEDIUMBLOB = 19 - TSDB_DATA_TYPE_MAX = 20 -) - -const ( - TSDB_DATA_TYPE_NULL_Str = "NULL" - TSDB_DATA_TYPE_BOOL_Str = "BOOL" - TSDB_DATA_TYPE_TINYINT_Str = "TINYINT" - TSDB_DATA_TYPE_SMALLINT_Str = "SMALLINT" - TSDB_DATA_TYPE_INT_Str = "INT" - TSDB_DATA_TYPE_BIGINT_Str = "BIGINT" - TSDB_DATA_TYPE_FLOAT_Str = "FLOAT" - TSDB_DATA_TYPE_DOUBLE_Str = "DOUBLE" - TSDB_DATA_TYPE_BINARY_Str = "VARCHAR" - TSDB_DATA_TYPE_TIMESTAMP_Str = "TIMESTAMP" - TSDB_DATA_TYPE_NCHAR_Str = "NCHAR" - TSDB_DATA_TYPE_UTINYINT_Str = "TINYINT UNSIGNED" - TSDB_DATA_TYPE_USMALLINT_Str = "SMALLINT UNSIGNED" - TSDB_DATA_TYPE_UINT_Str = "INT UNSIGNED" - TSDB_DATA_TYPE_UBIGINT_Str = "BIGINT UNSIGNED" - TSDB_DATA_TYPE_JSON_Str = "JSON" -) - -var TypeNameMap = map[int]string{ - TSDB_DATA_TYPE_NULL: TSDB_DATA_TYPE_NULL_Str, - TSDB_DATA_TYPE_BOOL: TSDB_DATA_TYPE_BOOL_Str, - TSDB_DATA_TYPE_TINYINT: TSDB_DATA_TYPE_TINYINT_Str, - TSDB_DATA_TYPE_SMALLINT: TSDB_DATA_TYPE_SMALLINT_Str, - TSDB_DATA_TYPE_INT: TSDB_DATA_TYPE_INT_Str, - TSDB_DATA_TYPE_BIGINT: TSDB_DATA_TYPE_BIGINT_Str, - TSDB_DATA_TYPE_FLOAT: TSDB_DATA_TYPE_FLOAT_Str, - TSDB_DATA_TYPE_DOUBLE: TSDB_DATA_TYPE_DOUBLE_Str, - TSDB_DATA_TYPE_BINARY: TSDB_DATA_TYPE_BINARY_Str, - TSDB_DATA_TYPE_TIMESTAMP: TSDB_DATA_TYPE_TIMESTAMP_Str, - TSDB_DATA_TYPE_NCHAR: TSDB_DATA_TYPE_NCHAR_Str, - TSDB_DATA_TYPE_UTINYINT: TSDB_DATA_TYPE_UTINYINT_Str, - TSDB_DATA_TYPE_USMALLINT: TSDB_DATA_TYPE_USMALLINT_Str, - TSDB_DATA_TYPE_UINT: TSDB_DATA_TYPE_UINT_Str, - TSDB_DATA_TYPE_UBIGINT: TSDB_DATA_TYPE_UBIGINT_Str, - TSDB_DATA_TYPE_JSON: TSDB_DATA_TYPE_JSON_Str, -} - -var NameTypeMap = map[string]int{ - TSDB_DATA_TYPE_NULL_Str: TSDB_DATA_TYPE_NULL, - TSDB_DATA_TYPE_BOOL_Str: TSDB_DATA_TYPE_BOOL, - TSDB_DATA_TYPE_TINYINT_Str: TSDB_DATA_TYPE_TINYINT, - TSDB_DATA_TYPE_SMALLINT_Str: TSDB_DATA_TYPE_SMALLINT, - TSDB_DATA_TYPE_INT_Str: TSDB_DATA_TYPE_INT, - TSDB_DATA_TYPE_BIGINT_Str: TSDB_DATA_TYPE_BIGINT, - TSDB_DATA_TYPE_FLOAT_Str: TSDB_DATA_TYPE_FLOAT, - TSDB_DATA_TYPE_DOUBLE_Str: TSDB_DATA_TYPE_DOUBLE, - TSDB_DATA_TYPE_BINARY_Str: TSDB_DATA_TYPE_BINARY, - TSDB_DATA_TYPE_TIMESTAMP_Str: TSDB_DATA_TYPE_TIMESTAMP, - TSDB_DATA_TYPE_NCHAR_Str: TSDB_DATA_TYPE_NCHAR, - TSDB_DATA_TYPE_UTINYINT_Str: TSDB_DATA_TYPE_UTINYINT, - TSDB_DATA_TYPE_USMALLINT_Str: TSDB_DATA_TYPE_USMALLINT, - TSDB_DATA_TYPE_UINT_Str: TSDB_DATA_TYPE_UINT, - TSDB_DATA_TYPE_UBIGINT_Str: TSDB_DATA_TYPE_UBIGINT, - TSDB_DATA_TYPE_JSON_Str: TSDB_DATA_TYPE_JSON, -} - const ( TMQ_RES_INVALID = -1 TMQ_RES_DATA = 1 @@ -141,3 +60,13 @@ const ( ) const ReqIDKey = "taos_req_id" + +const ( + TAOS_NOTIFY_PASSVER = 0 + TAOS_NOTIFY_WHITELIST_VER = 1 + TAOS_NOTIFY_USER_DROPPED = 2 +) + +const ( + TAOS_CONN_MODE_BI = 0 +) diff --git a/common/datatype.go b/common/datatype.go new file mode 100644 index 0000000..ea85688 --- /dev/null +++ b/common/datatype.go @@ -0,0 +1,299 @@ +package common + +import ( + "errors" + "reflect" +) + +type DBType struct { + IsVarData bool + ID int + Length int + Name string + ReflectType reflect.Type +} + +var NullType = DBType{ + ID: TSDB_DATA_TYPE_NULL, + Name: TSDB_DATA_TYPE_NULL_Str, + Length: 0, + ReflectType: UnknownType, + IsVarData: false, +} + +var BoolType = DBType{ + ID: TSDB_DATA_TYPE_BOOL, + Name: TSDB_DATA_TYPE_BOOL_Str, + Length: 1, + ReflectType: NullBool, + IsVarData: false, +} + +var TinyIntType = DBType{ + ID: TSDB_DATA_TYPE_TINYINT, + Name: TSDB_DATA_TYPE_TINYINT_Str, + Length: 1, + ReflectType: NullInt8, + IsVarData: false, +} + +var SmallIntType = DBType{ + ID: TSDB_DATA_TYPE_SMALLINT, + Name: TSDB_DATA_TYPE_SMALLINT_Str, + Length: 2, + ReflectType: NullInt16, + IsVarData: false, +} + +var IntType = DBType{ + ID: TSDB_DATA_TYPE_INT, + Name: TSDB_DATA_TYPE_INT_Str, + Length: 4, + ReflectType: NullInt32, + IsVarData: false, +} + +var BigIntType = DBType{ + ID: TSDB_DATA_TYPE_BIGINT, + Name: TSDB_DATA_TYPE_BIGINT_Str, + Length: 8, + ReflectType: NullInt64, + IsVarData: false, +} + +var UTinyIntType = DBType{ + ID: TSDB_DATA_TYPE_UTINYINT, + Name: TSDB_DATA_TYPE_UTINYINT_Str, + Length: 1, + ReflectType: NullUInt8, + IsVarData: false, +} + +var USmallIntType = DBType{ + ID: TSDB_DATA_TYPE_USMALLINT, + Name: TSDB_DATA_TYPE_USMALLINT_Str, + Length: 2, + ReflectType: NullUInt16, + IsVarData: false, +} + +var UIntType = DBType{ + ID: TSDB_DATA_TYPE_UINT, + Name: TSDB_DATA_TYPE_UINT_Str, + Length: 4, + ReflectType: NullUInt32, + IsVarData: false, +} + +var UBigIntType = DBType{ + ID: TSDB_DATA_TYPE_UBIGINT, + Name: TSDB_DATA_TYPE_UBIGINT_Str, + Length: 8, + ReflectType: NullUInt64, + IsVarData: false, +} + +var FloatType = DBType{ + ID: TSDB_DATA_TYPE_FLOAT, + Name: TSDB_DATA_TYPE_FLOAT_Str, + Length: 4, + ReflectType: NullFloat32, + IsVarData: false, +} + +var DoubleType = DBType{ + ID: TSDB_DATA_TYPE_DOUBLE, + Name: TSDB_DATA_TYPE_DOUBLE_Str, + Length: 8, + ReflectType: NullFloat64, + IsVarData: false, +} + +var BinaryType = DBType{ + ID: TSDB_DATA_TYPE_BINARY, + Name: TSDB_DATA_TYPE_BINARY_Str, + Length: 0, + ReflectType: NullString, + IsVarData: true, +} + +var NcharType = DBType{ + ID: TSDB_DATA_TYPE_NCHAR, + Name: TSDB_DATA_TYPE_NCHAR_Str, + Length: 0, + ReflectType: NullString, + IsVarData: true, +} + +var TimestampType = DBType{ + ID: TSDB_DATA_TYPE_TIMESTAMP, + Name: TSDB_DATA_TYPE_TIMESTAMP_Str, + Length: 8, + ReflectType: NullTime, + IsVarData: false, +} + +var JsonType = DBType{ + ID: TSDB_DATA_TYPE_JSON, + Name: TSDB_DATA_TYPE_JSON_Str, + Length: 0, + ReflectType: NullJson, + IsVarData: true, +} + +var VarBinaryType = DBType{ + ID: TSDB_DATA_TYPE_VARBINARY, + Name: TSDB_DATA_TYPE_VARBINARY_Str, + Length: 0, + ReflectType: NullString, + IsVarData: true, +} + +var GeometryType = DBType{ + ID: TSDB_DATA_TYPE_GEOMETRY, + Name: TSDB_DATA_TYPE_GEOMETRY_Str, + Length: 0, + ReflectType: NullString, + IsVarData: true, +} + +var allType = [21]*DBType{ + //TSDB_DATA_TYPE_NULL = 0 + &NullType, + //TSDB_DATA_TYPE_BOOL = 1 + &BoolType, + //TSDB_DATA_TYPE_TINYINT = 2 + &TinyIntType, + //TSDB_DATA_TYPE_SMALLINT = 3 + &SmallIntType, + //TSDB_DATA_TYPE_INT = 4 + &IntType, + //TSDB_DATA_TYPE_BIGINT = 5 + &BigIntType, + //TSDB_DATA_TYPE_FLOAT = 6 + &FloatType, + //TSDB_DATA_TYPE_DOUBLE = 7 + &DoubleType, + //TSDB_DATA_TYPE_BINARY = 8 + &BinaryType, + //TSDB_DATA_TYPE_TIMESTAMP = 9 + &TimestampType, + //TSDB_DATA_TYPE_NCHAR = 10 + &NcharType, + //TSDB_DATA_TYPE_UTINYINT = 11 + &UTinyIntType, + //TSDB_DATA_TYPE_USMALLINT = 12 + &USmallIntType, + //TSDB_DATA_TYPE_UINT = 13 + &UIntType, + //TSDB_DATA_TYPE_UBIGINT = 14 + &UBigIntType, + //TSDB_DATA_TYPE_JSON = 15 + &JsonType, + //TSDB_DATA_TYPE_VARBINARY = 16 + &VarBinaryType, + //TSDB_DATA_TYPE_DECIMAL = 17 + nil, + //TSDB_DATA_TYPE_BLOB = 18 + nil, + //TSDB_DATA_TYPE_MEDIUMBLOB = 19 + nil, + //TSDB_DATA_TYPE_GEOMETRY = 20 + &GeometryType, +} + +const ( + TSDB_DATA_TYPE_NULL = 0 // 1 bytes + TSDB_DATA_TYPE_BOOL = 1 // 1 bytes + TSDB_DATA_TYPE_TINYINT = 2 // 1 byte + TSDB_DATA_TYPE_SMALLINT = 3 // 2 bytes + TSDB_DATA_TYPE_INT = 4 // 4 bytes + TSDB_DATA_TYPE_BIGINT = 5 // 8 bytes + TSDB_DATA_TYPE_FLOAT = 6 // 4 bytes + TSDB_DATA_TYPE_DOUBLE = 7 // 8 bytes + TSDB_DATA_TYPE_BINARY = 8 // string + TSDB_DATA_TYPE_TIMESTAMP = 9 // 8 bytes + TSDB_DATA_TYPE_NCHAR = 10 // unicode string + TSDB_DATA_TYPE_UTINYINT = 11 // 1 byte + TSDB_DATA_TYPE_USMALLINT = 12 // 2 bytes + TSDB_DATA_TYPE_UINT = 13 // 4 bytes + TSDB_DATA_TYPE_UBIGINT = 14 // 8 bytes + TSDB_DATA_TYPE_JSON = 15 + TSDB_DATA_TYPE_VARBINARY = 16 + TSDB_DATA_TYPE_DECIMAL = 17 + TSDB_DATA_TYPE_BLOB = 18 + TSDB_DATA_TYPE_MEDIUMBLOB = 19 + TSDB_DATA_TYPE_GEOMETRY = 20 +) + +const ( + TSDB_DATA_TYPE_NULL_Str = "NULL" + TSDB_DATA_TYPE_BOOL_Str = "BOOL" + TSDB_DATA_TYPE_TINYINT_Str = "TINYINT" + TSDB_DATA_TYPE_SMALLINT_Str = "SMALLINT" + TSDB_DATA_TYPE_INT_Str = "INT" + TSDB_DATA_TYPE_BIGINT_Str = "BIGINT" + TSDB_DATA_TYPE_FLOAT_Str = "FLOAT" + TSDB_DATA_TYPE_DOUBLE_Str = "DOUBLE" + TSDB_DATA_TYPE_BINARY_Str = "VARCHAR" + TSDB_DATA_TYPE_TIMESTAMP_Str = "TIMESTAMP" + TSDB_DATA_TYPE_NCHAR_Str = "NCHAR" + TSDB_DATA_TYPE_UTINYINT_Str = "TINYINT UNSIGNED" + TSDB_DATA_TYPE_USMALLINT_Str = "SMALLINT UNSIGNED" + TSDB_DATA_TYPE_UINT_Str = "INT UNSIGNED" + TSDB_DATA_TYPE_UBIGINT_Str = "BIGINT UNSIGNED" + TSDB_DATA_TYPE_JSON_Str = "JSON" + TSDB_DATA_TYPE_VARBINARY_Str = "VARBINARY" + TSDB_DATA_TYPE_GEOMETRY_Str = "GEOMETRY" +) + +var TypeNameMap = map[int]string{ + TSDB_DATA_TYPE_NULL: TSDB_DATA_TYPE_NULL_Str, + TSDB_DATA_TYPE_BOOL: TSDB_DATA_TYPE_BOOL_Str, + TSDB_DATA_TYPE_TINYINT: TSDB_DATA_TYPE_TINYINT_Str, + TSDB_DATA_TYPE_SMALLINT: TSDB_DATA_TYPE_SMALLINT_Str, + TSDB_DATA_TYPE_INT: TSDB_DATA_TYPE_INT_Str, + TSDB_DATA_TYPE_BIGINT: TSDB_DATA_TYPE_BIGINT_Str, + TSDB_DATA_TYPE_FLOAT: TSDB_DATA_TYPE_FLOAT_Str, + TSDB_DATA_TYPE_DOUBLE: TSDB_DATA_TYPE_DOUBLE_Str, + TSDB_DATA_TYPE_BINARY: TSDB_DATA_TYPE_BINARY_Str, + TSDB_DATA_TYPE_TIMESTAMP: TSDB_DATA_TYPE_TIMESTAMP_Str, + TSDB_DATA_TYPE_NCHAR: TSDB_DATA_TYPE_NCHAR_Str, + TSDB_DATA_TYPE_UTINYINT: TSDB_DATA_TYPE_UTINYINT_Str, + TSDB_DATA_TYPE_USMALLINT: TSDB_DATA_TYPE_USMALLINT_Str, + TSDB_DATA_TYPE_UINT: TSDB_DATA_TYPE_UINT_Str, + TSDB_DATA_TYPE_UBIGINT: TSDB_DATA_TYPE_UBIGINT_Str, + TSDB_DATA_TYPE_JSON: TSDB_DATA_TYPE_JSON_Str, + TSDB_DATA_TYPE_VARBINARY: TSDB_DATA_TYPE_VARBINARY_Str, + TSDB_DATA_TYPE_GEOMETRY: TSDB_DATA_TYPE_GEOMETRY_Str, +} + +var NameTypeMap = map[string]int{ + TSDB_DATA_TYPE_NULL_Str: TSDB_DATA_TYPE_NULL, + TSDB_DATA_TYPE_BOOL_Str: TSDB_DATA_TYPE_BOOL, + TSDB_DATA_TYPE_TINYINT_Str: TSDB_DATA_TYPE_TINYINT, + TSDB_DATA_TYPE_SMALLINT_Str: TSDB_DATA_TYPE_SMALLINT, + TSDB_DATA_TYPE_INT_Str: TSDB_DATA_TYPE_INT, + TSDB_DATA_TYPE_BIGINT_Str: TSDB_DATA_TYPE_BIGINT, + TSDB_DATA_TYPE_FLOAT_Str: TSDB_DATA_TYPE_FLOAT, + TSDB_DATA_TYPE_DOUBLE_Str: TSDB_DATA_TYPE_DOUBLE, + TSDB_DATA_TYPE_BINARY_Str: TSDB_DATA_TYPE_BINARY, + TSDB_DATA_TYPE_TIMESTAMP_Str: TSDB_DATA_TYPE_TIMESTAMP, + TSDB_DATA_TYPE_NCHAR_Str: TSDB_DATA_TYPE_NCHAR, + TSDB_DATA_TYPE_UTINYINT_Str: TSDB_DATA_TYPE_UTINYINT, + TSDB_DATA_TYPE_USMALLINT_Str: TSDB_DATA_TYPE_USMALLINT, + TSDB_DATA_TYPE_UINT_Str: TSDB_DATA_TYPE_UINT, + TSDB_DATA_TYPE_UBIGINT_Str: TSDB_DATA_TYPE_UBIGINT, + TSDB_DATA_TYPE_JSON_Str: TSDB_DATA_TYPE_JSON, + TSDB_DATA_TYPE_VARBINARY_Str: TSDB_DATA_TYPE_VARBINARY, + TSDB_DATA_TYPE_GEOMETRY_Str: TSDB_DATA_TYPE_GEOMETRY, +} + +var NotSupportType = errors.New("not support type") + +func GetColType(colType int) (*DBType, error) { + if colType > len(allType) || colType < 0 { + return nil, NotSupportType + } + return allType[colType], nil +} diff --git a/common/param/column.go b/common/param/column.go index 8e5b68f..de5dedc 100644 --- a/common/param/column.go +++ b/common/param/column.go @@ -149,6 +149,18 @@ func (c *ColumnType) AddBinary(strMaxLen int) *ColumnType { return c } +func (c *ColumnType) AddVarBinary(strMaxLen int) *ColumnType { + if c.column >= c.size { + return c + } + c.value[c.column] = &types.ColumnType{ + Type: types.TaosVarBinaryType, + MaxLen: strMaxLen, + } + c.column += 1 + return c +} + func (c *ColumnType) AddNchar(strMaxLen int) *ColumnType { if c.column >= c.size { return c @@ -184,6 +196,18 @@ func (c *ColumnType) AddJson(strMaxLen int) *ColumnType { return c } +func (c *ColumnType) AddGeometry(strMaxLen int) *ColumnType { + if c.column >= c.size { + return c + } + c.value[c.column] = &types.ColumnType{ + Type: types.TaosGeometryType, + MaxLen: strMaxLen, + } + c.column += 1 + return c +} + func (c *ColumnType) GetValue() ([]*types.ColumnType, error) { if c.size != c.column { return nil, fmt.Errorf("incomplete column expect %d columns set %d columns", c.size, c.column) diff --git a/common/param/param.go b/common/param/param.go index a9ec02d..a14854b 100644 --- a/common/param/param.go +++ b/common/param/param.go @@ -111,6 +111,13 @@ func (p *Param) SetBinary(offset int, value []byte) { p.value[offset] = taosTypes.TaosBinary(value) } +func (p *Param) SetVarBinary(offset int, value []byte) { + if offset >= p.size { + return + } + p.value[offset] = taosTypes.TaosVarBinary(value) +} + func (p *Param) SetNchar(offset int, value string) { if offset >= p.size { return @@ -135,6 +142,13 @@ func (p *Param) SetJson(offset int, value []byte) { p.value[offset] = taosTypes.TaosJson(value) } +func (p *Param) SetGeometry(offset int, value []byte) { + if offset >= p.size { + return + } + p.value[offset] = taosTypes.TaosGeometry(value) +} + func (p *Param) AddBool(value bool) *Param { if p.offset >= p.size { return p @@ -252,6 +266,15 @@ func (p *Param) AddBinary(value []byte) *Param { return p } +func (p *Param) AddVarBinary(value []byte) *Param { + if p.offset >= p.size { + return p + } + p.value[p.offset] = taosTypes.TaosVarBinary(value) + p.offset += 1 + return p +} + func (p *Param) AddNchar(value string) *Param { if p.offset >= p.size { return p @@ -282,6 +305,15 @@ func (p *Param) AddJson(value []byte) *Param { return p } +func (p *Param) AddGeometry(value []byte) *Param { + if p.offset >= p.size { + return p + } + p.value[p.offset] = taosTypes.TaosGeometry(value) + p.offset += 1 + return p +} + func (p *Param) GetValues() []driver.Value { return p.value } diff --git a/common/parser/block.go b/common/parser/block.go index f573804..7228e90 100644 --- a/common/parser/block.go +++ b/common/parser/block.go @@ -6,6 +6,7 @@ import ( "unsafe" "github.com/taosdata/driver-go/v3/common" + "github.com/taosdata/driver-go/v3/common/pointer" ) const ( @@ -33,27 +34,27 @@ const ( ) func RawBlockGetVersion(rawBlock unsafe.Pointer) int32 { - return *((*int32)(unsafe.Pointer(uintptr(rawBlock) + RawBlockVersionOffset))) + return *((*int32)(pointer.AddUintptr(rawBlock, RawBlockVersionOffset))) } func RawBlockGetLength(rawBlock unsafe.Pointer) int32 { - return *((*int32)(unsafe.Pointer(uintptr(rawBlock) + RawBlockLengthOffset))) + return *((*int32)(pointer.AddUintptr(rawBlock, RawBlockLengthOffset))) } func RawBlockGetNumOfRows(rawBlock unsafe.Pointer) int32 { - return *((*int32)(unsafe.Pointer(uintptr(rawBlock) + NumOfRowsOffset))) + return *((*int32)(pointer.AddUintptr(rawBlock, NumOfRowsOffset))) } func RawBlockGetNumOfCols(rawBlock unsafe.Pointer) int32 { - return *((*int32)(unsafe.Pointer(uintptr(rawBlock) + NumOfColsOffset))) + return *((*int32)(pointer.AddUintptr(rawBlock, NumOfColsOffset))) } func RawBlockGetHasColumnSegment(rawBlock unsafe.Pointer) int32 { - return *((*int32)(unsafe.Pointer(uintptr(rawBlock) + HasColumnSegmentOffset))) + return *((*int32)(pointer.AddUintptr(rawBlock, HasColumnSegmentOffset))) } func RawBlockGetGroupID(rawBlock unsafe.Pointer) uint64 { - return *((*uint64)(unsafe.Pointer(uintptr(rawBlock) + GroupIDOffset))) + return *((*uint64)(pointer.AddUintptr(rawBlock, GroupIDOffset))) } type RawBlockColInfo struct { @@ -63,9 +64,9 @@ type RawBlockColInfo struct { func RawBlockGetColInfo(rawBlock unsafe.Pointer, infos []RawBlockColInfo) { for i := 0; i < len(infos); i++ { - offset := uintptr(rawBlock) + ColInfoOffset + ColInfoSize*uintptr(i) - infos[i].ColType = *((*int8)(unsafe.Pointer(offset))) - infos[i].Bytes = *((*int32)(unsafe.Pointer(offset + Int8Size))) + offset := ColInfoOffset + ColInfoSize*uintptr(i) + infos[i].ColType = *((*int8)(pointer.AddUintptr(rawBlock, offset))) + infos[i].Bytes = *((*int32)(pointer.AddUintptr(rawBlock, offset+Int8Size))) } } @@ -80,7 +81,11 @@ func RawBlockGetColDataOffset(colCount int) uintptr { type FormatTimeFunc func(ts int64, precision int) driver.Value func IsVarDataType(colType uint8) bool { - return colType == common.TSDB_DATA_TYPE_BINARY || colType == common.TSDB_DATA_TYPE_NCHAR || colType == common.TSDB_DATA_TYPE_JSON + return colType == common.TSDB_DATA_TYPE_BINARY || + colType == common.TSDB_DATA_TYPE_NCHAR || + colType == common.TSDB_DATA_TYPE_JSON || + colType == common.TSDB_DATA_TYPE_VARBINARY || + colType == common.TSDB_DATA_TYPE_GEOMETRY } func BitmapLen(n int) int { @@ -99,9 +104,9 @@ func BMIsNull(c byte, n int) bool { return c&(1<<(7-BitPos(n))) == (1 << (7 - BitPos(n))) } -type rawConvertFunc func(pStart uintptr, row int, arg ...interface{}) driver.Value +type rawConvertFunc func(pStart unsafe.Pointer, row int, arg ...interface{}) driver.Value -type rawConvertVarDataFunc func(pHeader, pStart uintptr, row int) driver.Value +type rawConvertVarDataFunc func(pHeader, pStart unsafe.Pointer, row int) driver.Value var rawConvertFuncMap = map[uint8]rawConvertFunc{ uint8(common.TSDB_DATA_TYPE_BOOL): rawConvertBool, @@ -119,122 +124,145 @@ var rawConvertFuncMap = map[uint8]rawConvertFunc{ } var rawConvertVarDataMap = map[uint8]rawConvertVarDataFunc{ - uint8(common.TSDB_DATA_TYPE_BINARY): rawConvertBinary, - uint8(common.TSDB_DATA_TYPE_NCHAR): rawConvertNchar, - uint8(common.TSDB_DATA_TYPE_JSON): rawConvertJson, + uint8(common.TSDB_DATA_TYPE_BINARY): rawConvertBinary, + uint8(common.TSDB_DATA_TYPE_NCHAR): rawConvertNchar, + uint8(common.TSDB_DATA_TYPE_JSON): rawConvertJson, + uint8(common.TSDB_DATA_TYPE_VARBINARY): rawConvertVarBinary, + uint8(common.TSDB_DATA_TYPE_GEOMETRY): rawConvertGeometry, } -func ItemIsNull(pHeader uintptr, row int) bool { +func ItemIsNull(pHeader unsafe.Pointer, row int) bool { offset := CharOffset(row) - c := *((*byte)(unsafe.Pointer(pHeader + uintptr(offset)))) + c := *((*byte)(pointer.AddUintptr(pHeader, uintptr(offset)))) return BMIsNull(c, row) } -func rawConvertBool(pStart uintptr, row int, _ ...interface{}) driver.Value { - if (*((*byte)(unsafe.Pointer(pStart + uintptr(row)*1)))) != 0 { +func rawConvertBool(pStart unsafe.Pointer, row int, _ ...interface{}) driver.Value { + if (*((*byte)(pointer.AddUintptr(pStart, uintptr(row)*1)))) != 0 { return true } else { return false } } -func rawConvertTinyint(pStart uintptr, row int, _ ...interface{}) driver.Value { - return *((*int8)(unsafe.Pointer(pStart + uintptr(row)*Int8Size))) +func rawConvertTinyint(pStart unsafe.Pointer, row int, _ ...interface{}) driver.Value { + return *((*int8)(pointer.AddUintptr(pStart, uintptr(row)*Int8Size))) } -func rawConvertSmallint(pStart uintptr, row int, _ ...interface{}) driver.Value { - return *((*int16)(unsafe.Pointer(pStart + uintptr(row)*Int16Size))) +func rawConvertSmallint(pStart unsafe.Pointer, row int, _ ...interface{}) driver.Value { + return *((*int16)(pointer.AddUintptr(pStart, uintptr(row)*Int16Size))) } -func rawConvertInt(pStart uintptr, row int, _ ...interface{}) driver.Value { - return *((*int32)(unsafe.Pointer(pStart + uintptr(row)*Int32Size))) +func rawConvertInt(pStart unsafe.Pointer, row int, _ ...interface{}) driver.Value { + return *((*int32)(pointer.AddUintptr(pStart, uintptr(row)*Int32Size))) } -func rawConvertBigint(pStart uintptr, row int, _ ...interface{}) driver.Value { - return *((*int64)(unsafe.Pointer(pStart + uintptr(row)*Int64Size))) +func rawConvertBigint(pStart unsafe.Pointer, row int, _ ...interface{}) driver.Value { + return *((*int64)(pointer.AddUintptr(pStart, uintptr(row)*Int64Size))) } -func rawConvertUTinyint(pStart uintptr, row int, _ ...interface{}) driver.Value { - return *((*uint8)(unsafe.Pointer(pStart + uintptr(row)*UInt8Size))) +func rawConvertUTinyint(pStart unsafe.Pointer, row int, _ ...interface{}) driver.Value { + return *((*uint8)(pointer.AddUintptr(pStart, uintptr(row)*UInt8Size))) } -func rawConvertUSmallint(pStart uintptr, row int, _ ...interface{}) driver.Value { - return *((*uint16)(unsafe.Pointer(pStart + uintptr(row)*UInt16Size))) +func rawConvertUSmallint(pStart unsafe.Pointer, row int, _ ...interface{}) driver.Value { + return *((*uint16)(pointer.AddUintptr(pStart, uintptr(row)*UInt16Size))) } -func rawConvertUInt(pStart uintptr, row int, _ ...interface{}) driver.Value { - return *((*uint32)(unsafe.Pointer(pStart + uintptr(row)*UInt32Size))) +func rawConvertUInt(pStart unsafe.Pointer, row int, _ ...interface{}) driver.Value { + return *((*uint32)(pointer.AddUintptr(pStart, uintptr(row)*UInt32Size))) } -func rawConvertUBigint(pStart uintptr, row int, _ ...interface{}) driver.Value { - return *((*uint64)(unsafe.Pointer(pStart + uintptr(row)*UInt64Size))) +func rawConvertUBigint(pStart unsafe.Pointer, row int, _ ...interface{}) driver.Value { + return *((*uint64)(pointer.AddUintptr(pStart, uintptr(row)*UInt64Size))) } -func rawConvertFloat(pStart uintptr, row int, _ ...interface{}) driver.Value { - return math.Float32frombits(*((*uint32)(unsafe.Pointer(pStart + uintptr(row)*Float32Size)))) +func rawConvertFloat(pStart unsafe.Pointer, row int, _ ...interface{}) driver.Value { + return math.Float32frombits(*((*uint32)(pointer.AddUintptr(pStart, uintptr(row)*Float32Size)))) } -func rawConvertDouble(pStart uintptr, row int, _ ...interface{}) driver.Value { - return math.Float64frombits(*((*uint64)(unsafe.Pointer(pStart + uintptr(row)*Float64Size)))) +func rawConvertDouble(pStart unsafe.Pointer, row int, _ ...interface{}) driver.Value { + return math.Float64frombits(*((*uint64)(pointer.AddUintptr(pStart, uintptr(row)*Float64Size)))) } -func rawConvertTime(pStart uintptr, row int, arg ...interface{}) driver.Value { +func rawConvertTime(pStart unsafe.Pointer, row int, arg ...interface{}) driver.Value { if len(arg) == 1 { - return common.TimestampConvertToTime(*((*int64)(unsafe.Pointer(pStart + uintptr(row)*Int64Size))), arg[0].(int)) + return common.TimestampConvertToTime(*((*int64)(pointer.AddUintptr(pStart, uintptr(row)*Int64Size))), arg[0].(int)) } else if len(arg) == 2 { - return arg[1].(FormatTimeFunc)(*((*int64)(unsafe.Pointer(pStart + uintptr(row)*Int64Size))), arg[0].(int)) + return arg[1].(FormatTimeFunc)(*((*int64)(pointer.AddUintptr(pStart, uintptr(row)*Int64Size))), arg[0].(int)) } else { panic("convertTime error") } } -func rawConvertBinary(pHeader, pStart uintptr, row int) driver.Value { - offset := *((*int32)(unsafe.Pointer(pHeader + uintptr(row*4)))) +func rawConvertVarBinary(pHeader, pStart unsafe.Pointer, row int) driver.Value { + offset := *((*int32)(pointer.AddUintptr(pHeader, uintptr(row*4)))) if offset == -1 { return nil } - currentRow := unsafe.Pointer(pStart + uintptr(offset)) - clen := *((*int16)(currentRow)) + currentRow := pointer.AddUintptr(pStart, uintptr(offset)) + clen := *((*uint16)(currentRow)) currentRow = unsafe.Pointer(uintptr(currentRow) + 2) binaryVal := make([]byte, clen) - for index := int16(0); index < clen; index++ { + for index := uint16(0); index < clen; index++ { + binaryVal[index] = *((*byte)(unsafe.Pointer(uintptr(currentRow) + uintptr(index)))) + } + return binaryVal[:] +} + +func rawConvertGeometry(pHeader, pStart unsafe.Pointer, row int) driver.Value { + return rawConvertVarBinary(pHeader, pStart, row) +} + +func rawConvertBinary(pHeader, pStart unsafe.Pointer, row int) driver.Value { + offset := *((*int32)(pointer.AddUintptr(pHeader, uintptr(row*4)))) + if offset == -1 { + return nil + } + currentRow := pointer.AddUintptr(pStart, uintptr(offset)) + clen := *((*uint16)(currentRow)) + currentRow = unsafe.Pointer(uintptr(currentRow) + 2) + + binaryVal := make([]byte, clen) + + for index := uint16(0); index < clen; index++ { binaryVal[index] = *((*byte)(unsafe.Pointer(uintptr(currentRow) + uintptr(index)))) } return string(binaryVal[:]) } -func rawConvertNchar(pHeader, pStart uintptr, row int) driver.Value { - offset := *((*int32)(unsafe.Pointer(pHeader + uintptr(row*4)))) +func rawConvertNchar(pHeader, pStart unsafe.Pointer, row int) driver.Value { + offset := *((*int32)(pointer.AddUintptr(pHeader, uintptr(row*4)))) if offset == -1 { return nil } - currentRow := unsafe.Pointer(pStart + uintptr(offset)) - clen := *((*int16)(currentRow)) / 4 + currentRow := pointer.AddUintptr(pStart, uintptr(offset)) + clen := *((*uint16)(currentRow)) / 4 currentRow = unsafe.Pointer(uintptr(currentRow) + 2) binaryVal := make([]rune, clen) - for index := int16(0); index < clen; index++ { + for index := uint16(0); index < clen; index++ { binaryVal[index] = *((*rune)(unsafe.Pointer(uintptr(currentRow) + uintptr(index*4)))) } return string(binaryVal) } -func rawConvertJson(pHeader, pStart uintptr, row int) driver.Value { - offset := *((*int32)(unsafe.Pointer(pHeader + uintptr(row*4)))) +func rawConvertJson(pHeader, pStart unsafe.Pointer, row int) driver.Value { + offset := *((*int32)(pointer.AddUintptr(pHeader, uintptr(row*4)))) if offset == -1 { return nil } - currentRow := unsafe.Pointer(pStart + uintptr(offset)) - clen := *((*int16)(currentRow)) - currentRow = unsafe.Pointer(uintptr(currentRow) + 2) + currentRow := pointer.AddUintptr(pStart, uintptr(offset)) + clen := *((*uint16)(currentRow)) + currentRow = pointer.AddUintptr(currentRow, 2) binaryVal := make([]byte, clen) - for index := int16(0); index < clen; index++ { - binaryVal[index] = *((*byte)(unsafe.Pointer(uintptr(currentRow) + uintptr(index)))) + for index := uint16(0); index < clen; index++ { + binaryVal[index] = *((*byte)(pointer.AddUintptr(currentRow, uintptr(index)))) } return binaryVal[:] } @@ -245,13 +273,13 @@ func ReadBlock(block unsafe.Pointer, blockSize int, colTypes []uint8, precision colCount := len(colTypes) nullBitMapOffset := uintptr(BitmapLen(blockSize)) lengthOffset := RawBlockGetColumnLengthOffset(colCount) - pHeader := uintptr(block) + RawBlockGetColDataOffset(colCount) - var pStart uintptr + pHeader := pointer.AddUintptr(block, RawBlockGetColDataOffset(colCount)) + var pStart unsafe.Pointer for column := 0; column < colCount; column++ { - colLength := *((*int32)(unsafe.Pointer(uintptr(block) + lengthOffset + uintptr(column)*Int32Size))) + colLength := *((*int32)(pointer.AddUintptr(block, lengthOffset+uintptr(column)*Int32Size))) if IsVarDataType(colTypes[column]) { convertF := rawConvertVarDataMap[colTypes[column]] - pStart = pHeader + Int32Size*uintptr(blockSize) + pStart = pointer.AddUintptr(pHeader, Int32Size*uintptr(blockSize)) for row := 0; row < blockSize; row++ { if column == 0 { r[row] = make([]driver.Value, colCount) @@ -260,7 +288,7 @@ func ReadBlock(block unsafe.Pointer, blockSize int, colTypes []uint8, precision } } else { convertF := rawConvertFuncMap[colTypes[column]] - pStart = pHeader + nullBitMapOffset + pStart = pointer.AddUintptr(pHeader, nullBitMapOffset) for row := 0; row < blockSize; row++ { if column == 0 { r[row] = make([]driver.Value, colCount) @@ -272,7 +300,7 @@ func ReadBlock(block unsafe.Pointer, blockSize int, colTypes []uint8, precision } } } - pHeader = pStart + uintptr(colLength) + pHeader = pointer.AddUintptr(pStart, uintptr(colLength)) } return r } @@ -281,24 +309,24 @@ func ReadRow(dest []driver.Value, block unsafe.Pointer, blockSize int, row int, colCount := len(colTypes) nullBitMapOffset := uintptr(BitmapLen(blockSize)) lengthOffset := RawBlockGetColumnLengthOffset(colCount) - pHeader := uintptr(block) + RawBlockGetColDataOffset(colCount) - var pStart uintptr + pHeader := pointer.AddUintptr(block, RawBlockGetColDataOffset(colCount)) + var pStart unsafe.Pointer for column := 0; column < colCount; column++ { - colLength := *((*int32)(unsafe.Pointer(uintptr(block) + lengthOffset + uintptr(column)*Int32Size))) + colLength := *((*int32)(pointer.AddUintptr(block, lengthOffset+uintptr(column)*Int32Size))) if IsVarDataType(colTypes[column]) { convertF := rawConvertVarDataMap[colTypes[column]] - pStart = pHeader + Int32Size*uintptr(blockSize) + pStart = pointer.AddUintptr(pHeader, Int32Size*uintptr(blockSize)) dest[column] = convertF(pHeader, pStart, row) } else { convertF := rawConvertFuncMap[colTypes[column]] - pStart = pHeader + nullBitMapOffset + pStart = pointer.AddUintptr(pHeader, nullBitMapOffset) if ItemIsNull(pHeader, row) { dest[column] = nil } else { dest[column] = convertF(pStart, row, precision) } } - pHeader = pStart + uintptr(colLength) + pHeader = pointer.AddUintptr(pStart, uintptr(colLength)) } } @@ -307,13 +335,13 @@ func ReadBlockWithTimeFormat(block unsafe.Pointer, blockSize int, colTypes []uin colCount := len(colTypes) nullBitMapOffset := uintptr(BitmapLen(blockSize)) lengthOffset := RawBlockGetColumnLengthOffset(colCount) - pHeader := uintptr(block) + RawBlockGetColDataOffset(colCount) - var pStart uintptr + pHeader := pointer.AddUintptr(block, RawBlockGetColDataOffset(colCount)) + var pStart unsafe.Pointer for column := 0; column < colCount; column++ { - colLength := *((*int32)(unsafe.Pointer(uintptr(block) + lengthOffset + uintptr(column)*Int32Size))) + colLength := *((*int32)(pointer.AddUintptr(block, lengthOffset+uintptr(column)*Int32Size))) if IsVarDataType(colTypes[column]) { convertF := rawConvertVarDataMap[colTypes[column]] - pStart = pHeader + uintptr(4*blockSize) + pStart = pointer.AddUintptr(pHeader, uintptr(4*blockSize)) for row := 0; row < blockSize; row++ { if column == 0 { r[row] = make([]driver.Value, colCount) @@ -322,7 +350,7 @@ func ReadBlockWithTimeFormat(block unsafe.Pointer, blockSize int, colTypes []uin } } else { convertF := rawConvertFuncMap[colTypes[column]] - pStart = pHeader + nullBitMapOffset + pStart = pointer.AddUintptr(pHeader, nullBitMapOffset) for row := 0; row < blockSize; row++ { if column == 0 { r[row] = make([]driver.Value, colCount) @@ -334,52 +362,19 @@ func ReadBlockWithTimeFormat(block unsafe.Pointer, blockSize int, colTypes []uin } } } - pHeader = pStart + uintptr(colLength) + pHeader = pointer.AddUintptr(pStart, uintptr(colLength)) } return r } -func ItemRawBlock(colType uint8, pHeader, pStart uintptr, row int, precision int, timeFormat FormatTimeFunc) driver.Value { +func ItemRawBlock(colType uint8, pHeader, pStart unsafe.Pointer, row int, precision int, timeFormat FormatTimeFunc) driver.Value { if IsVarDataType(colType) { - switch colType { - case uint8(common.TSDB_DATA_TYPE_BINARY): - return rawConvertBinary(pHeader, pStart, row) - case uint8(common.TSDB_DATA_TYPE_NCHAR): - return rawConvertNchar(pHeader, pStart, row) - case uint8(common.TSDB_DATA_TYPE_JSON): - return rawConvertJson(pHeader, pStart, row) - } + return rawConvertVarDataMap[colType](pHeader, pStart, row) } else { if ItemIsNull(pHeader, row) { return nil } else { - switch colType { - case uint8(common.TSDB_DATA_TYPE_BOOL): - return rawConvertBool(pStart, row) - case uint8(common.TSDB_DATA_TYPE_TINYINT): - return rawConvertTinyint(pStart, row) - case uint8(common.TSDB_DATA_TYPE_SMALLINT): - return rawConvertSmallint(pStart, row) - case uint8(common.TSDB_DATA_TYPE_INT): - return rawConvertInt(pStart, row) - case uint8(common.TSDB_DATA_TYPE_BIGINT): - return rawConvertBigint(pStart, row) - case uint8(common.TSDB_DATA_TYPE_UTINYINT): - return rawConvertUTinyint(pStart, row) - case uint8(common.TSDB_DATA_TYPE_USMALLINT): - return rawConvertUSmallint(pStart, row) - case uint8(common.TSDB_DATA_TYPE_UINT): - return rawConvertUInt(pStart, row) - case uint8(common.TSDB_DATA_TYPE_UBIGINT): - return rawConvertUBigint(pStart, row) - case uint8(common.TSDB_DATA_TYPE_FLOAT): - return rawConvertFloat(pStart, row) - case uint8(common.TSDB_DATA_TYPE_DOUBLE): - return rawConvertDouble(pStart, row) - case uint8(common.TSDB_DATA_TYPE_TIMESTAMP): - return rawConvertTime(pStart, row, precision, timeFormat) - } + return rawConvertFuncMap[colType](pStart, row, precision, timeFormat) } } - return nil } diff --git a/common/parser/block_test.go b/common/parser/block_test.go index bd327de..5b7232a 100644 --- a/common/parser/block_test.go +++ b/common/parser/block_test.go @@ -9,10 +9,14 @@ import ( "github.com/stretchr/testify/assert" "github.com/taosdata/driver-go/v3/common" + "github.com/taosdata/driver-go/v3/common/pointer" "github.com/taosdata/driver-go/v3/errors" "github.com/taosdata/driver-go/v3/wrapper" ) +// @author: xftan +// @date: 2023/10/13 11:13 +// @description: test block func TestReadBlock(t *testing.T) { conn, err := wrapper.TaosConnect("", "root", "taosdata", "", 0) if err != nil { @@ -102,8 +106,8 @@ func TestReadBlock(t *testing.T) { return } precision := wrapper.TaosResultPrecision(res) - pHeaderList := make([]uintptr, fileCount) - pStartList := make([]uintptr, fileCount) + pHeaderList := make([]unsafe.Pointer, fileCount) + pStartList := make([]unsafe.Pointer, fileCount) var data [][]driver.Value for { blockSize, errCode, block := wrapper.TaosFetchRawBlock(res) @@ -119,20 +123,20 @@ func TestReadBlock(t *testing.T) { } nullBitMapOffset := uintptr(BitmapLen(blockSize)) lengthOffset := RawBlockGetColumnLengthOffset(fileCount) - tmpPHeader := uintptr(block) + RawBlockGetColDataOffset(fileCount) - var tmpPStart uintptr + tmpPHeader := pointer.AddUintptr(block, RawBlockGetColDataOffset(fileCount)) + var tmpPStart unsafe.Pointer for column := 0; column < fileCount; column++ { - colLength := *((*int32)(unsafe.Pointer(uintptr(block) + lengthOffset + uintptr(column)*Int32Size))) + colLength := *((*int32)(pointer.AddUintptr(block, lengthOffset+uintptr(column)*Int32Size))) if IsVarDataType(rh.ColTypes[column]) { pHeaderList[column] = tmpPHeader - tmpPStart = tmpPHeader + Int32Size*uintptr(blockSize) + tmpPStart = pointer.AddUintptr(tmpPHeader, Int32Size*uintptr(blockSize)) pStartList[column] = tmpPStart } else { pHeaderList[column] = tmpPHeader - tmpPStart = tmpPHeader + nullBitMapOffset + tmpPStart = pointer.AddUintptr(tmpPHeader, nullBitMapOffset) pStartList[column] = tmpPStart } - tmpPHeader = tmpPStart + uintptr(colLength) + tmpPHeader = pointer.AddUintptr(tmpPStart, uintptr(colLength)) } for row := 0; row < blockSize; row++ { rowV := make([]driver.Value, fileCount) @@ -169,6 +173,9 @@ func TestReadBlock(t *testing.T) { } } +// @author: xftan +// @date: 2023/10/13 11:13 +// @description: test block tag func TestBlockTag(t *testing.T) { conn, err := wrapper.TaosConnect("", "root", "taosdata", "", 0) if err != nil { @@ -244,8 +251,8 @@ func TestBlockTag(t *testing.T) { return } precision := wrapper.TaosResultPrecision(res) - pHeaderList := make([]uintptr, fileCount) - pStartList := make([]uintptr, fileCount) + pHeaderList := make([]unsafe.Pointer, fileCount) + pStartList := make([]unsafe.Pointer, fileCount) var data [][]driver.Value for { blockSize, errCode, block := wrapper.TaosFetchRawBlock(res) @@ -261,20 +268,20 @@ func TestBlockTag(t *testing.T) { } nullBitMapOffset := uintptr(BitmapLen(blockSize)) lengthOffset := RawBlockGetColumnLengthOffset(fileCount) - tmpPHeader := uintptr(block) + RawBlockGetColDataOffset(fileCount) // length i32, group u64 - var tmpPStart uintptr + tmpPHeader := pointer.AddUintptr(block, RawBlockGetColDataOffset(fileCount)) // length i32, group u64 + var tmpPStart unsafe.Pointer for column := 0; column < fileCount; column++ { - colLength := *((*int32)(unsafe.Pointer(uintptr(block) + lengthOffset + uintptr(column)*Int32Size))) + colLength := *((*int32)(pointer.AddUintptr(block, lengthOffset+uintptr(column)*Int32Size))) if IsVarDataType(rh.ColTypes[column]) { pHeaderList[column] = tmpPHeader - tmpPStart = tmpPHeader + Int32Size*uintptr(blockSize) + tmpPStart = pointer.AddUintptr(tmpPHeader, Int32Size*uintptr(blockSize)) pStartList[column] = tmpPStart } else { pHeaderList[column] = tmpPHeader - tmpPStart = tmpPHeader + nullBitMapOffset + tmpPStart = pointer.AddUintptr(tmpPHeader, nullBitMapOffset) pStartList[column] = tmpPStart } - tmpPHeader = tmpPStart + uintptr(colLength) + tmpPHeader = pointer.AddUintptr(tmpPStart, uintptr(colLength)) } for row := 0; row < blockSize; row++ { rowV := make([]driver.Value, fileCount) @@ -292,6 +299,9 @@ func TestBlockTag(t *testing.T) { t.Log(len(data[0][1].(string))) } +// @author: xftan +// @date: 2023/10/13 11:18 +// @description: test read row func TestReadRow(t *testing.T) { conn, err := wrapper.TaosConnect("", "root", "taosdata", "", 0) if err != nil { @@ -427,6 +437,9 @@ func TestReadRow(t *testing.T) { assert.Equal(t, []byte(`{"a":1}`), row2[14].([]byte)) } +// @author: xftan +// @date: 2023/10/13 11:18 +// @description: test read block with time format func TestReadBlockWithTimeFormat(t *testing.T) { conn, err := wrapper.TaosConnect("", "root", "taosdata", "", 0) if err != nil { @@ -560,6 +573,9 @@ func TestReadBlockWithTimeFormat(t *testing.T) { assert.Equal(t, []byte(`{"a":1}`), row2[14].([]byte)) } +// @author: xftan +// @date: 2023/10/13 11:18 +// @description: test parse block func TestParseBlock(t *testing.T) { conn, err := wrapper.TaosConnect("", "root", "taosdata", "", 0) if err != nil { @@ -611,7 +627,9 @@ func TestParseBlock(t *testing.T) { "c10 float,"+ "c11 double,"+ "c12 binary(20),"+ - "c13 nchar(20)"+ + "c13 nchar(20),"+ + "c14 varbinary(20),"+ + "c15 geometry(100)"+ ") tags (info json)") code = wrapper.TaosError(res) if code != 0 { @@ -623,7 +641,9 @@ func TestParseBlock(t *testing.T) { wrapper.TaosFreeResult(res) now := time.Now() after1s := now.Add(time.Second) - sql := fmt.Sprintf("insert into parse_block.t0 using parse_block.all_type tags('{\"a\":1}') values('%s',1,1,1,1,1,1,1,1,1,1,1,'test_binary','test_nchar')('%s',null,null,null,null,null,null,null,null,null,null,null,null,null)", now.Format(time.RFC3339Nano), after1s.Format(time.RFC3339Nano)) + sql := fmt.Sprintf("insert into parse_block.t0 using parse_block.all_type tags('{\"a\":1}') "+ + "values('%s',1,1,1,1,1,1,1,1,1,1,1,'test_binary','test_nchar','test_varbinary','POINT(100 100)')"+ + "('%s',null,null,null,null,null,null,null,null,null,null,null,null,null,null,null)", now.Format(time.RFC3339Nano), after1s.Format(time.RFC3339Nano)) res = wrapper.TaosQuery(conn, sql) code = wrapper.TaosError(res) if code != 0 { @@ -666,11 +686,11 @@ func TestParseBlock(t *testing.T) { version := RawBlockGetVersion(block) assert.Equal(t, int32(1), version) length := RawBlockGetLength(block) - assert.Equal(t, int32(374), length) + assert.Equal(t, int32(447), length) rows := RawBlockGetNumOfRows(block) assert.Equal(t, int32(2), rows) columns := RawBlockGetNumOfCols(block) - assert.Equal(t, int32(15), columns) + assert.Equal(t, int32(17), columns) hasColumnSegment := RawBlockGetHasColumnSegment(block) assert.Equal(t, int32(-2147483648), hasColumnSegment) groupId := RawBlockGetGroupID(block) @@ -736,6 +756,14 @@ func TestParseBlock(t *testing.T) { ColType: 10, Bytes: 82, }, + { + ColType: 16, + Bytes: 22, + }, + { + ColType: 20, + Bytes: 102, + }, { ColType: 15, Bytes: 16384, @@ -763,11 +791,13 @@ func TestParseBlock(t *testing.T) { assert.Equal(t, float64(1), row1[11].(float64)) assert.Equal(t, "test_binary", row1[12].(string)) assert.Equal(t, "test_nchar", row1[13].(string)) - assert.Equal(t, []byte(`{"a":1}`), row1[14].([]byte)) + assert.Equal(t, []byte("test_varbinary"), row1[14].([]byte)) + assert.Equal(t, []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, row1[15].([]byte)) + assert.Equal(t, []byte(`{"a":1}`), row1[16].([]byte)) row2 := data[1] assert.Equal(t, after1s.UnixNano()/1e6, row2[0].(time.Time).UnixNano()/1e6) - for i := 1; i < 14; i++ { + for i := 1; i < 16; i++ { assert.Nil(t, row2[i]) } - assert.Equal(t, []byte(`{"a":1}`), row2[14].([]byte)) + assert.Equal(t, []byte(`{"a":1}`), row2[16].([]byte)) } diff --git a/common/pointer/unsafe.go b/common/pointer/unsafe.go new file mode 100644 index 0000000..7e5e36f --- /dev/null +++ b/common/pointer/unsafe.go @@ -0,0 +1,7 @@ +package pointer + +import "unsafe" + +func AddUintptr(ptr unsafe.Pointer, len uintptr) unsafe.Pointer { + return unsafe.Pointer(uintptr(ptr) + len) +} diff --git a/common/reqid.go b/common/reqid.go index f5d711f..02f1c72 100644 --- a/common/reqid.go +++ b/common/reqid.go @@ -8,6 +8,7 @@ import ( "unsafe" "github.com/google/uuid" + "github.com/taosdata/driver-go/v3/common/pointer" ) var tUUIDHashId int64 @@ -36,10 +37,9 @@ func murmurHash32(data []byte, seed uint32) uint32 { h1 := seed nBlocks := len(data) / 4 - p := uintptr(unsafe.Pointer(&data[0])) - p1 := p + uintptr(4*nBlocks) - for ; p < p1; p += 4 { - k1 := *(*uint32)(unsafe.Pointer(p)) + p := unsafe.Pointer(&data[0]) + for i := 0; i < nBlocks; i++ { + k1 := *(*uint32)(pointer.AddUintptr(p, uintptr(i*4))) k1 *= c1 k1 = bits.RotateLeft32(k1, 15) diff --git a/common/reqid_test.go b/common/reqid_test.go index 58e2fef..4afb056 100644 --- a/common/reqid_test.go +++ b/common/reqid_test.go @@ -24,10 +24,16 @@ func BenchmarkGetReqIDParallel(b *testing.B) { }) } +// @author: xftan +// @date: 2023/10/13 11:20 +// @description: test get req id func TestGetReqID(t *testing.T) { t.Log(GetReqID()) } +// @author: xftan +// @date: 2023/10/13 11:20 +// @description: test MurmurHash func TestMurmurHash(t *testing.T) { if murmurHash32([]byte("driver-go"), 0) != 3037880692 { t.Fatal("fail") diff --git a/common/serializer/block.go b/common/serializer/block.go index 4a23b53..03a53f2 100644 --- a/common/serializer/block.go +++ b/common/serializer/block.go @@ -366,6 +366,62 @@ func SerializeRawBlock(params []*param.Param, colType *param.ColumnType) ([]byte } lengthData = appendUint32(lengthData, uint32(length)) data = append(data, dataTmp...) + case taosTypes.TaosVarBinaryType: + colInfoData = append(colInfoData, common.TSDB_DATA_TYPE_VARBINARY) + colInfoData = appendUint32(colInfoData, uint32(0)) + length := 0 + dataTmp := make([]byte, Int32Size*rows) + rowData := params[colIndex].GetValues() + for rowIndex := 0; rowIndex < rows; rowIndex++ { + offset := Int32Size * rowIndex + if rowData[rowIndex] == nil { + for i := 0; i < Int32Size; i++ { + // -1 + dataTmp[offset+i] = byte(255) + } + } else { + v, is := rowData[rowIndex].(taosTypes.TaosVarBinary) + if !is { + return nil, DataTypeWrong + } + for i := 0; i < Int32Size; i++ { + dataTmp[offset+i] = byte(length >> (8 * i)) + } + dataTmp = appendUint16(dataTmp, uint16(len(v))) + dataTmp = append(dataTmp, v...) + length += len(v) + Int16Size + } + } + lengthData = appendUint32(lengthData, uint32(length)) + data = append(data, dataTmp...) + case taosTypes.TaosGeometryType: + colInfoData = append(colInfoData, common.TSDB_DATA_TYPE_GEOMETRY) + colInfoData = appendUint32(colInfoData, uint32(0)) + length := 0 + dataTmp := make([]byte, Int32Size*rows) + rowData := params[colIndex].GetValues() + for rowIndex := 0; rowIndex < rows; rowIndex++ { + offset := Int32Size * rowIndex + if rowData[rowIndex] == nil { + for i := 0; i < Int32Size; i++ { + // -1 + dataTmp[offset+i] = byte(255) + } + } else { + v, is := rowData[rowIndex].(taosTypes.TaosGeometry) + if !is { + return nil, DataTypeWrong + } + for i := 0; i < Int32Size; i++ { + dataTmp[offset+i] = byte(length >> (8 * i)) + } + dataTmp = appendUint16(dataTmp, uint16(len(v))) + dataTmp = append(dataTmp, v...) + length += len(v) + Int16Size + } + } + lengthData = appendUint32(lengthData, uint32(length)) + data = append(data, dataTmp...) case taosTypes.TaosNcharType: colInfoData = append(colInfoData, common.TSDB_DATA_TYPE_NCHAR) colInfoData = appendUint32(colInfoData, uint32(0)) diff --git a/common/serializer/block_test.go b/common/serializer/block_test.go index 6023592..c59d988 100644 --- a/common/serializer/block_test.go +++ b/common/serializer/block_test.go @@ -9,6 +9,9 @@ import ( "github.com/taosdata/driver-go/v3/common/param" ) +// @author: xftan +// @date: 2023/10/13 11:19 +// @description: test block func TestSerializeRawBlock(t *testing.T) { type args struct { params []*param.Param @@ -143,9 +146,55 @@ func TestSerializeRawBlock(t *testing.T) { param.NewParam(3).AddDouble(1).AddNull().AddDouble(1), param.NewParam(3).AddBinary([]byte("test_binary")).AddNull().AddBinary([]byte("test_binary")), param.NewParam(3).AddNchar("test_nchar").AddNull().AddNchar("test_nchar"), + param.NewParam(3).AddVarBinary([]byte("test_varbinary")).AddNull().AddVarBinary([]byte("test_varbinary")), + param.NewParam(3).AddGeometry([]byte{ + 0x01, + 0x01, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x59, + 0x40, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x59, + 0x40, + }).AddNull().AddGeometry([]byte{ + 0x01, + 0x01, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x59, + 0x40, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x59, + 0x40, + }), param.NewParam(3).AddJson([]byte("{\"a\":1}")).AddNull().AddJson([]byte("{\"a\":1}")), }, - colType: param.NewColumnType(15). + colType: param.NewColumnType(17). AddTimestamp(). AddBool(). AddTinyint(). @@ -160,13 +209,15 @@ func TestSerializeRawBlock(t *testing.T) { AddDouble(). AddBinary(0). AddNchar(0). + AddVarBinary(0). + AddGeometry(0). AddJson(0), }, want: []byte{ 0x01, 0x00, 0x00, 0x00, - 0xec, 0x01, 0x00, 0x00, + 0x64, 0x02, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, - 0x0f, 0x00, 0x00, 0x00, + 0x11, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, //types @@ -184,6 +235,8 @@ func TestSerializeRawBlock(t *testing.T) { 0x07, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x00, + 0x14, 0x00, 0x00, 0x00, 0x00, 0x0f, 0x00, 0x00, 0x00, 0x00, //lengths 0x18, 0x00, 0x00, 0x00, @@ -200,6 +253,8 @@ func TestSerializeRawBlock(t *testing.T) { 0x18, 0x00, 0x00, 0x00, 0x1a, 0x00, 0x00, 0x00, 0x54, 0x00, 0x00, 0x00, + 0x20, 0x00, 0x00, 0x00, + 0x2e, 0x00, 0x00, 0x00, 0x12, 0x00, 0x00, 0x00, // ts 0x40, @@ -297,6 +352,24 @@ func TestSerializeRawBlock(t *testing.T) { 0x6e, 0x00, 0x00, 0x00, 0x63, 0x00, 0x00, 0x00, 0x68, 0x00, 0x00, 0x00, 0x61, 0x00, 0x00, 0x00, 0x72, 0x00, 0x00, 0x00, + //varbinary + 0x00, 0x00, 0x00, 0x00, + 0xff, 0xff, 0xff, 0xff, + 0x10, 0x00, 0x00, 0x00, + 0x0e, 0x00, + 0x74, 0x65, 0x73, 0x74, 0x5f, 0x76, 0x61, 0x72, 0x62, 0x69, 0x6e, 0x61, 0x72, 0x79, + 0x0e, 0x00, + 0x74, 0x65, 0x73, 0x74, 0x5f, 0x76, 0x61, 0x72, 0x62, 0x69, 0x6e, 0x61, 0x72, 0x79, + + //geometry + 0x00, 0x00, 0x00, 0x00, + 0xff, 0xff, 0xff, 0xff, + 0x17, 0x00, 0x00, 0x00, + 0x15, 0x00, + 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, + 0x15, 0x00, + 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, + //json 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, diff --git a/common/stmt/field.go b/common/stmt/field.go new file mode 100644 index 0000000..20c51b9 --- /dev/null +++ b/common/stmt/field.go @@ -0,0 +1,56 @@ +package stmt + +import ( + "fmt" + + "github.com/taosdata/driver-go/v3/common" + "github.com/taosdata/driver-go/v3/types" +) + +type StmtField struct { + Name string `json:"name"` + FieldType int8 `json:"field_type"` + Precision uint8 `json:"precision"` + Scale uint8 `json:"scale"` + Bytes int32 `json:"bytes"` +} + +func (s *StmtField) GetType() (*types.ColumnType, error) { + switch s.FieldType { + case common.TSDB_DATA_TYPE_BOOL: + return &types.ColumnType{Type: types.TaosBoolType}, nil + case common.TSDB_DATA_TYPE_TINYINT: + return &types.ColumnType{Type: types.TaosTinyintType}, nil + case common.TSDB_DATA_TYPE_SMALLINT: + return &types.ColumnType{Type: types.TaosSmallintType}, nil + case common.TSDB_DATA_TYPE_INT: + return &types.ColumnType{Type: types.TaosIntType}, nil + case common.TSDB_DATA_TYPE_BIGINT: + return &types.ColumnType{Type: types.TaosBigintType}, nil + case common.TSDB_DATA_TYPE_UTINYINT: + return &types.ColumnType{Type: types.TaosUTinyintType}, nil + case common.TSDB_DATA_TYPE_USMALLINT: + return &types.ColumnType{Type: types.TaosUSmallintType}, nil + case common.TSDB_DATA_TYPE_UINT: + return &types.ColumnType{Type: types.TaosUIntType}, nil + case common.TSDB_DATA_TYPE_UBIGINT: + return &types.ColumnType{Type: types.TaosUBigintType}, nil + case common.TSDB_DATA_TYPE_FLOAT: + return &types.ColumnType{Type: types.TaosFloatType}, nil + case common.TSDB_DATA_TYPE_DOUBLE: + return &types.ColumnType{Type: types.TaosDoubleType}, nil + case common.TSDB_DATA_TYPE_BINARY: + return &types.ColumnType{Type: types.TaosBinaryType}, nil + case common.TSDB_DATA_TYPE_VARBINARY: + return &types.ColumnType{Type: types.TaosVarBinaryType}, nil + case common.TSDB_DATA_TYPE_NCHAR: + return &types.ColumnType{Type: types.TaosNcharType}, nil + case common.TSDB_DATA_TYPE_TIMESTAMP: + return &types.ColumnType{Type: types.TaosTimestampType}, nil + case common.TSDB_DATA_TYPE_JSON: + return &types.ColumnType{Type: types.TaosJsonType}, nil + case common.TSDB_DATA_TYPE_GEOMETRY: + return &types.ColumnType{Type: types.TaosGeometryType}, nil + } + return nil, fmt.Errorf("unsupported type: %d, name %s", s.FieldType, s.Name) +} diff --git a/common/tmq/event.go b/common/tmq/event.go index e6a9e80..c62cc93 100644 --- a/common/tmq/event.go +++ b/common/tmq/event.go @@ -61,12 +61,15 @@ type Message interface { Topic() string DBName() string Value() interface{} + Offset() int64 } type DataMessage struct { - dbName string - topic string - data []*Data + TopicPartition TopicPartition + dbName string + topic string + data []*Data + offset Offset } func (m *DataMessage) String() string { @@ -86,6 +89,10 @@ func (m *DataMessage) SetData(data []*Data) { m.data = data } +func (m *DataMessage) SetOffset(offset Offset) { + m.offset = offset +} + func (m *DataMessage) Topic() string { return m.topic } @@ -98,11 +105,20 @@ func (m *DataMessage) Value() interface{} { return m.data } +func (m *DataMessage) Offset() Offset { + return m.offset +} + type MetaMessage struct { - dbName string - topic string - offset string - meta *Meta + TopicPartition TopicPartition + dbName string + topic string + offset Offset + meta *Meta +} + +func (m *MetaMessage) Offset() Offset { + return m.offset } func (m *MetaMessage) String() string { @@ -118,7 +134,7 @@ func (m *MetaMessage) SetTopic(topic string) { m.topic = topic } -func (m *MetaMessage) SetOffset(offset string) { +func (m *MetaMessage) SetOffset(offset Offset) { m.offset = offset } @@ -139,10 +155,15 @@ func (m *MetaMessage) Value() interface{} { } type MetaDataMessage struct { - dbName string - topic string - offset string - metaData *MetaData + TopicPartition TopicPartition + dbName string + topic string + offset Offset + metaData *MetaData +} + +func (m *MetaDataMessage) Offset() Offset { + return m.offset } func (m *MetaDataMessage) String() string { @@ -158,7 +179,7 @@ func (m *MetaDataMessage) SetTopic(topic string) { m.topic = topic } -func (m *MetaDataMessage) SetOffset(offset string) { +func (m *MetaDataMessage) SetOffset(offset Offset) { m.offset = offset } diff --git a/common/tmq/tmq.go b/common/tmq/tmq.go index a43607b..a6e5617 100644 --- a/common/tmq/tmq.go +++ b/common/tmq/tmq.go @@ -1,5 +1,7 @@ package tmq +import "fmt" + type Meta struct { Type string `json:"type"` TableName string `json:"tableName"` @@ -38,5 +40,48 @@ type CreateItem struct { Tags []*Tag `json:"tags"` } +type Offset int64 + +const OffsetInvalid = Offset(-2147467247) + +func (o Offset) String() string { + if o == OffsetInvalid { + return "unset" + } + return fmt.Sprintf("%d", int64(o)) +} + +func (o Offset) Valid() bool { + if o < 0 && o != OffsetInvalid { + return false + } + return true +} + type TopicPartition struct { + Topic *string + Partition int32 + Offset Offset + Metadata *string + Error error +} + +func (p TopicPartition) String() string { + topic := "" + if p.Topic != nil { + topic = *p.Topic + } + if p.Error != nil { + return fmt.Sprintf("%s[%d]@%s(%s)", + topic, p.Partition, p.Offset, p.Error) + } + return fmt.Sprintf("%s[%d]@%s", + topic, p.Partition, p.Offset) +} + +type Assignment struct { + VGroupID int32 `json:"vgroup_id"` + Offset int64 `json:"offset"` + Begin int64 `json:"begin"` + End int64 `json:"end"` } diff --git a/common/tmq/tmq_test.go b/common/tmq/tmq_test.go index c15999b..7279c1d 100644 --- a/common/tmq/tmq_test.go +++ b/common/tmq/tmq_test.go @@ -41,6 +41,9 @@ const dropJson = `{ "tableNameList":["t1", "t2"] }` +// @author: xftan +// @date: 2023/10/13 11:19 +// @description: test json func TestCreateJson(t *testing.T) { var obj Meta err := json.Unmarshal([]byte(createJson), &obj) @@ -51,6 +54,9 @@ func TestCreateJson(t *testing.T) { t.Log(obj) } +// @author: xftan +// @date: 2023/10/13 11:19 +// @description: test drop json func TestDropJson(t *testing.T) { var obj Meta err := json.Unmarshal([]byte(dropJson), &obj) diff --git a/errors/errors_test.go b/errors/errors_test.go index f6f7daa..e29f9ae 100644 --- a/errors/errors_test.go +++ b/errors/errors_test.go @@ -2,6 +2,9 @@ package errors import "testing" +// @author: xftan +// @date: 2023/10/13 11:20 +// @description: test new error func TestNewError(t *testing.T) { type args struct { code int diff --git a/examples/sqlstmt/main.go b/examples/sqlstmt/main.go new file mode 100644 index 0000000..1215d42 --- /dev/null +++ b/examples/sqlstmt/main.go @@ -0,0 +1,121 @@ +package main + +import ( + "database/sql" + "fmt" + "time" + + _ "github.com/taosdata/driver-go/v3/taosSql" +) + +var ( + driverName = "taosSql" + user = "root" + password = "taosdata" + host = "" + port = 6030 + dataSourceName = fmt.Sprintf("%s:%s@/tcp(%s:%d)/%s?interpolateParams=true", user, password, host, port, "") +) + +func main() { + db, err := sql.Open(driverName, dataSourceName) + if err != nil { + panic(err) + } + defer db.Close() + defer func() { + db.Exec("drop database if exists test_stmt_driver") + }() + _, err = db.Exec("create database if not exists test_stmt_driver") + if err != nil { + panic(err) + } + _, err = db.Exec("create table if not exists test_stmt_driver.ct(ts timestamp," + + "c1 bool," + + "c2 tinyint," + + "c3 smallint," + + "c4 int," + + "c5 bigint," + + "c6 tinyint unsigned," + + "c7 smallint unsigned," + + "c8 int unsigned," + + "c9 bigint unsigned," + + "c10 float," + + "c11 double," + + "c12 binary(20)," + + "c13 nchar(20)" + + ")") + if err != nil { + panic(err) + } + stmt, err := db.Prepare("insert into test_stmt_driver.ct values (?,?,?,?,?,?,?,?,?,?,?,?,?,?)") + if err != nil { + panic(err) + } + now := time.Now() + result, err := stmt.Exec(now, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, "binary", "nchar") + if err != nil { + panic(err) + } + affected, err := result.RowsAffected() + if err != nil { + panic(err) + } + fmt.Println("affected", affected) + stmt.Close() + cr := 0 + err = db.QueryRow("select count(*) from test_stmt_driver.ct where ts = ?", now).Scan(&cr) + if err != nil { + panic(err) + } + fmt.Println("count", cr) + stmt, err = db.Prepare("select * from test_stmt_driver.ct where ts = ?") + if err != nil { + panic(err) + } + rows, err := stmt.Query(now) + if err != nil { + panic(err) + } + columns, err := rows.Columns() + if err != nil { + panic(err) + } + fmt.Println(columns) + count := 0 + for rows.Next() { + count += 1 + var ( + ts time.Time + c1 bool + c2 int8 + c3 int16 + c4 int32 + c5 int64 + c6 uint8 + c7 uint16 + c8 uint32 + c9 uint64 + c10 float32 + c11 float64 + c12 string + c13 string + ) + err = rows.Scan(&ts, + &c1, + &c2, + &c3, + &c4, + &c5, + &c6, + &c7, + &c8, + &c9, + &c10, + &c11, + &c12, + &c13) + fmt.Println(ts, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13) + } + fmt.Println("rows", count) +} diff --git a/examples/tmq/main.go b/examples/tmq/main.go index bf52a0e..eb110ce 100644 --- a/examples/tmq/main.go +++ b/examples/tmq/main.go @@ -35,7 +35,6 @@ func main() { "td.connect.port": "6030", "client.id": "test_tmq_client", "enable.auto.commit": "false", - "enable.heartbeat.background": "true", "experimental.snapshot.enable": "true", "msg.with.table.name": "true", }) @@ -55,17 +54,42 @@ func main() { panic(err) } for i := 0; i < 5; i++ { - ev := consumer.Poll(0) + ev := consumer.Poll(500) if ev != nil { switch e := ev.(type) { case *tmqcommon.DataMessage: - fmt.Println(e.Value()) + fmt.Printf("get message:%v\n", e) case tmqcommon.Error: fmt.Fprintf(os.Stderr, "%% Error: %v: %v\n", e.Code(), e) panic(e) } + consumer.Commit() } } + partitions, err := consumer.Assignment() + if err != nil { + panic(err) + } + for i := 0; i < len(partitions); i++ { + fmt.Println(partitions[i]) + err = consumer.Seek(tmqcommon.TopicPartition{ + Topic: partitions[i].Topic, + Partition: partitions[i].Partition, + Offset: 0, + }, 0) + if err != nil { + panic(err) + } + } + + partitions, err = consumer.Assignment() + if err != nil { + panic(err) + } + for i := 0; i < len(partitions); i++ { + fmt.Println(partitions[i]) + } + err = consumer.Close() if err != nil { panic(err) diff --git a/examples/tmqoverws/main.go b/examples/tmqoverws/main.go index 9d9eda3..b0fdf91 100644 --- a/examples/tmqoverws/main.go +++ b/examples/tmqoverws/main.go @@ -60,17 +60,42 @@ func main() { } }() for i := 0; i < 5; i++ { - ev := consumer.Poll(0) + ev := consumer.Poll(500) if ev != nil { switch e := ev.(type) { case *tmqcommon.DataMessage: - fmt.Printf("get message:%v", e) + fmt.Printf("get message:%v\n", e) case tmqcommon.Error: fmt.Printf("%% Error: %v: %v\n", e.Code(), e) panic(e) } + consumer.Commit() } } + partitions, err := consumer.Assignment() + if err != nil { + panic(err) + } + for i := 0; i < len(partitions); i++ { + fmt.Println(partitions[i]) + err = consumer.Seek(tmqcommon.TopicPartition{ + Topic: partitions[i].Topic, + Partition: partitions[i].Partition, + Offset: 0, + }, 0) + if err != nil { + panic(err) + } + } + + partitions, err = consumer.Assignment() + if err != nil { + panic(err) + } + for i := 0; i < len(partitions); i++ { + fmt.Println(partitions[i]) + } + err = consumer.Close() if err != nil { panic(err) diff --git a/taosRestful/driver_test.go b/taosRestful/driver_test.go index 77d6afa..c214d70 100644 --- a/taosRestful/driver_test.go +++ b/taosRestful/driver_test.go @@ -245,33 +245,33 @@ func TestChinese(t *testing.T) { } defer db.Close() defer func() { - _, err = db.Exec("drop database if exists test_chinese") + _, err = db.Exec("drop database if exists test_chinese_rest") if err != nil { t.Error(err) return } }() - _, err = db.Exec("create database if not exists test_chinese") + _, err = db.Exec("create database if not exists test_chinese_rest") if err != nil { t.Error(err) return } - _, err = db.Exec("drop table if exists test_chinese.chinese") + _, err = db.Exec("drop table if exists test_chinese_rest.chinese") if err != nil { t.Error(err) return } - _, err = db.Exec("create table if not exists test_chinese.chinese(ts timestamp,v nchar(32))") + _, err = db.Exec("create table if not exists test_chinese_rest.chinese(ts timestamp,v nchar(32))") if err != nil { t.Error(err) return } - _, err = db.Exec(`INSERT INTO test_chinese.chinese (ts, v) VALUES (?, ?)`, "1641010332000", "'阴天'") + _, err = db.Exec(`INSERT INTO test_chinese_rest.chinese (ts, v) VALUES (?, ?)`, "1641010332000", "'阴天'") if err != nil { t.Error(err) return } - rows, err := db.Query("select * from test_chinese.chinese") + rows, err := db.Query("select * from test_chinese_rest.chinese") if err != nil { t.Error(err) return diff --git a/taosSql/connection.go b/taosSql/connection.go index 086ee09..288a74f 100644 --- a/taosSql/connection.go +++ b/taosSql/connection.go @@ -3,7 +3,6 @@ package taosSql import ( "context" "database/sql/driver" - errors2 "errors" "unsafe" "github.com/taosdata/driver-go/v3/common" @@ -49,6 +48,7 @@ func (tc *taosConn) Prepare(query string) (driver.Stmt, error) { } locker.Lock() isInsert, code := wrapper.TaosStmtIsInsert(stmtP) + locker.Unlock() if code != 0 { errStr := wrapper.TaosStmtErrStr(stmtP) err := errors.NewError(code, errStr) @@ -57,13 +57,6 @@ func (tc *taosConn) Prepare(query string) (driver.Stmt, error) { locker.Unlock() return nil, err } - if !isInsert { - locker.Lock() - wrapper.TaosStmtClose(stmtP) - locker.Unlock() - return nil, errors2.New("only supports insert statements") - } - locker.Unlock() stmt := &Stmt{ tc: tc, pSql: query, diff --git a/taosSql/connection_test.go b/taosSql/connection_test.go index 1d56107..f473ec5 100644 --- a/taosSql/connection_test.go +++ b/taosSql/connection_test.go @@ -8,6 +8,9 @@ import ( "github.com/taosdata/driver-go/v3/common" ) +// @author: xftan +// @date: 2023/10/13 11:21 +// @description: test taos connection exec context func TestTaosConn_ExecContext(t *testing.T) { ctx := context.WithValue(context.Background(), common.ReqIDKey, common.GetReqID()) db, err := sql.Open("taosSql", dataSourceName) diff --git a/taosSql/driver_test.go b/taosSql/driver_test.go index 9dfc33c..51a76be 100644 --- a/taosSql/driver_test.go +++ b/taosSql/driver_test.go @@ -309,43 +309,43 @@ func TestJson(t *testing.T) { } defer db.Close() defer func() { - _, err = db.Exec("drop database if exists test_json") + _, err = db.Exec("drop database if exists test_json_native") if err != nil { t.Error(err) return } }() - _, err = db.Exec("create database if not exists test_json") + _, err = db.Exec("create database if not exists test_json_native") if err != nil { t.Error(err) return } - _, err = db.Exec("drop table if exists test_json.tjson") + _, err = db.Exec("drop table if exists test_json_native.tjson") if err != nil { t.Error(err) return } - _, err = db.Exec("create stable if not exists test_json.tjson(ts timestamp,v int )tags(t json)") + _, err = db.Exec("create stable if not exists test_json_native.tjson(ts timestamp,v int )tags(t json)") if err != nil { t.Error(err) return } - _, err = db.Exec(`insert into test_json.tj_1 using test_json.tjson tags('{"a":1,"b":"b"}')values (now,1)`) + _, err = db.Exec(`insert into test_json_native.tj_1 using test_json_native.tjson tags('{"a":1,"b":"b"}')values (now,1)`) if err != nil { t.Error(err) return } - _, err = db.Exec(`insert into test_json.tj_2 using test_json.tjson tags('{"a":1,"c":"c"}')values (now,1)`) + _, err = db.Exec(`insert into test_json_native.tj_2 using test_json_native.tjson tags('{"a":1,"c":"c"}')values (now,1)`) if err != nil { t.Error(err) return } - _, err = db.Exec(`insert into test_json.tj_3 using test_json.tjson tags('null')values (now,1)`) + _, err = db.Exec(`insert into test_json_native.tj_3 using test_json_native.tjson tags('null')values (now,1)`) if err != nil { t.Error(err) return } - rows, err := db.Query("select * from test_json.tjson") + rows, err := db.Query("select * from test_json_native.tjson") if err != nil { t.Error(err) return @@ -385,38 +385,38 @@ func TestJsonSearch(t *testing.T) { } defer db.Close() defer func() { - _, err = db.Exec("drop database if exists test_json") + _, err = db.Exec("drop database if exists test_json_native_search") if err != nil { t.Error(err) return } }() - _, err = db.Exec("create database if not exists test_json") + _, err = db.Exec("create database if not exists test_json_native_search") if err != nil { t.Error(err) return } - _, err = db.Exec("drop table if exists test_json.tjson_search") + _, err = db.Exec("drop table if exists test_json_native_search.tjson_search") if err != nil { t.Error(err) return } - _, err = db.Exec("create stable if not exists test_json.tjson_search(ts timestamp,v int )tags(t json)") + _, err = db.Exec("create stable if not exists test_json_native_search.tjson_search(ts timestamp,v int )tags(t json)") if err != nil { t.Error(err) return } - _, err = db.Exec(`insert into test_json.tjs_1 using test_json.tjson_search tags('{"a":1,"b":"b"}')values (now,1)`) + _, err = db.Exec(`insert into test_json_native_search.tjs_1 using test_json_native_search.tjson_search tags('{"a":1,"b":"b"}')values (now,1)`) if err != nil { t.Error(err) return } - _, err = db.Exec(`insert into test_json.tjs_2 using test_json.tjson_search tags('{"a":1,"c":"c"}')values (now,2)`) + _, err = db.Exec(`insert into test_json_native_search.tjs_2 using test_json_native_search.tjson_search tags('{"a":1,"c":"c"}')values (now,2)`) if err != nil { t.Error(err) return } - rows, err := db.Query("select * from test_json.tjson_search where t contains 'a' and t->'b'='b' and v = 1") + rows, err := db.Query("select * from test_json_native_search.tjson_search where t contains 'a' and t->'b'='b' and v = 1") if err != nil { t.Error(err) return @@ -451,38 +451,38 @@ func TestJsonMatch(t *testing.T) { } defer db.Close() defer func() { - _, err = db.Exec("drop database if exists test_json") + _, err = db.Exec("drop database if exists test_json_native_match") if err != nil { t.Error(err) return } }() - _, err = db.Exec("create database if not exists test_json") + _, err = db.Exec("create database if not exists test_json_native_match") if err != nil { t.Error(err) return } - _, err = db.Exec("drop table if exists test_json.tjson_match") + _, err = db.Exec("drop table if exists test_json_native_match.tjson_match") if err != nil { t.Error(err) return } - _, err = db.Exec("create stable if not exists test_json.tjson_match(ts timestamp,v int )tags(t json)") + _, err = db.Exec("create stable if not exists test_json_native_match.tjson_match(ts timestamp,v int )tags(t json)") if err != nil { t.Error(err) return } - _, err = db.Exec(`insert into test_json.tjm_1 using test_json.tjson_match tags('{"a":1,"b":"b"}')values (now,1)`) + _, err = db.Exec(`insert into test_json_native_match.tjm_1 using test_json_native_match.tjson_match tags('{"a":1,"b":"b"}')values (now,1)`) if err != nil { t.Error(err) return } - _, err = db.Exec(`insert into test_json.tjm_2 using test_json.tjson_match tags('{"a":1,"c":"c"}')values (now,2)`) + _, err = db.Exec(`insert into test_json_native_match.tjm_2 using test_json_native_match.tjson_match tags('{"a":1,"c":"c"}')values (now,2)`) if err != nil { t.Error(err) return } - rows, err := db.Query("select * from test_json.tjson_match where t contains 'a' and t->'b' match '.*b.*|.*e.*' and v = 1") + rows, err := db.Query("select * from test_json_native_match.tjson_match where t contains 'a' and t->'b' match '.*b.*|.*e.*' and v = 1") if err != nil { t.Error(err) return @@ -516,33 +516,33 @@ func TestChinese(t *testing.T) { } defer db.Close() defer func() { - _, err = db.Exec("drop database if exists test_chinese") + _, err = db.Exec("drop database if exists test_chinese_native") if err != nil { t.Error(err) return } }() - _, err = db.Exec("create database if not exists test_chinese") + _, err = db.Exec("create database if not exists test_chinese_native") if err != nil { t.Error(err) return } - _, err = db.Exec("drop table if exists test_chinese.chinese") + _, err = db.Exec("drop table if exists test_chinese_native.chinese") if err != nil { t.Error(err) return } - _, err = db.Exec("create table if not exists test_chinese.chinese(ts timestamp,v nchar(32))") + _, err = db.Exec("create table if not exists test_chinese_native.chinese(ts timestamp,v nchar(32))") if err != nil { t.Error(err) return } - _, err = db.Exec(`INSERT INTO test_chinese.chinese (ts, v) VALUES (?, ?)`, "1641010332000", "'阴天'") + _, err = db.Exec(`INSERT INTO test_chinese_native.chinese (ts, v) VALUES (?, ?)`, "1641010332000", "'阴天'") if err != nil { t.Error(err) return } - rows, err := db.Query("select * from test_chinese.chinese") + rows, err := db.Query("select * from test_chinese_native.chinese") if err != nil { t.Error(err) return diff --git a/taosSql/rows.go b/taosSql/rows.go index 685f66f..aaf684d 100644 --- a/taosSql/rows.go +++ b/taosSql/rows.go @@ -22,6 +22,7 @@ type rows struct { lengthList []int result unsafe.Pointer precision int + isStmt bool } func (rs *rows) Columns() []string { @@ -41,7 +42,16 @@ func (rs *rows) ColumnTypeScanType(i int) reflect.Type { } func (rs *rows) Close() error { - rs.freeResult() + if rs.handler != nil { + asyncHandlerPool.Put(rs.handler) + rs.handler = nil + } + if !rs.isStmt && rs.result != nil { + locker.Lock() + wrapper.TaosFreeResult(rs.result) + locker.Unlock() + } + rs.result = nil rs.block = nil return nil } @@ -82,6 +92,8 @@ func (rs *rows) Next(dest []driver.Value) error { } func (rs *rows) taosFetchBlock() error { + //rs.blockSize, rs.block = wrapper.TaosFetchBlock(rs.result) + //return nil result := rs.asyncFetchRows() if result.N == 0 { rs.blockSize = 0 @@ -107,16 +119,3 @@ func (rs *rows) asyncFetchRows() *handler.AsyncResult { r := <-rs.handler.Caller.FetchResult return r } - -func (rs *rows) freeResult() { - if rs.handler != nil { - asyncHandlerPool.Put(rs.handler) - rs.handler = nil - } - if rs.result != nil { - locker.Lock() - wrapper.TaosFreeResult(rs.result) - locker.Unlock() - rs.result = nil - } -} diff --git a/taosSql/statement.go b/taosSql/statement.go index f9fea69..e103a0e 100644 --- a/taosSql/statement.go +++ b/taosSql/statement.go @@ -2,7 +2,6 @@ package taosSql import ( "database/sql/driver" - errors2 "errors" "fmt" "reflect" "strconv" @@ -10,6 +9,7 @@ import ( "unsafe" "github.com/taosdata/driver-go/v3/common" + stmtCommon "github.com/taosdata/driver-go/v3/common/stmt" "github.com/taosdata/driver-go/v3/errors" "github.com/taosdata/driver-go/v3/types" "github.com/taosdata/driver-go/v3/wrapper" @@ -23,8 +23,7 @@ type Stmt struct { tc *taosConn pSql string isInsert bool - cols []*wrapper.StmtField - //tags []*wrapper.StmtField + cols []*stmtCommon.StmtField } func (stmt *Stmt) Close() error { @@ -46,7 +45,7 @@ func (stmt *Stmt) NumInput() int { func (stmt *Stmt) Exec(args []driver.Value) (driver.Result, error) { if stmt.tc == nil || stmt.tc.taos == nil { - return nil, errors.ErrTscInvalidConnection + return nil, driver.ErrBadConn } if len(args) != len(stmt.cols) { return nil, fmt.Errorf("stmt exec error: wrong number of parameters") @@ -73,42 +72,42 @@ func (stmt *Stmt) Exec(args []driver.Value) (driver.Result, error) { } func (stmt *Stmt) Query(args []driver.Value) (driver.Rows, error) { - return nil, errors2.New("unsupported") - //if stmt.tc == nil || stmt.tc.taos == nil { - // return nil, errors.ErrTscInvalidConnection - //} - //locker.Lock() - //defer locker.Unlock() - //code := wrapper.TaosStmtBindParam(stmt.stmt, args) - //if code != 0 { - // errStr := wrapper.TaosStmtErrStr(stmt.stmt) - // return nil, errors.NewError(code, errStr) - //} - //code = wrapper.TaosStmtAddBatch(stmt.stmt) - //if code != 0 { - // errStr := wrapper.TaosStmtErrStr(stmt.stmt) - // return nil, errors.NewError(code, errStr) - //} - //code = wrapper.TaosStmtExecute(stmt.stmt) - //if code != 0 { - // errStr := wrapper.TaosStmtErrStr(stmt.stmt) - // return nil, errors.NewError(code, errStr) - //} - //res := wrapper.TaosStmtUseResult(stmt.stmt) - //handler := asyncHandlerPool.Get() - //numFields := wrapper.TaosNumFields(res) - //rowsHeader, err := wrapper.ReadColumn(res, numFields) - //if err != nil { - // return nil, err - //} - //precision := wrapper.TaosResultPrecision(res) - //rs := &rows{ - // handler: handler, - // rowsHeader: rowsHeader, - // result: res, - // precision: precision, - //} - //return rs, nil + if stmt.tc == nil || stmt.tc.taos == nil { + return nil, driver.ErrBadConn + } + locker.Lock() + defer locker.Unlock() + code := wrapper.TaosStmtBindParam(stmt.stmt, args) + if code != 0 { + errStr := wrapper.TaosStmtErrStr(stmt.stmt) + return nil, errors.NewError(code, errStr) + } + code = wrapper.TaosStmtAddBatch(stmt.stmt) + if code != 0 { + errStr := wrapper.TaosStmtErrStr(stmt.stmt) + return nil, errors.NewError(code, errStr) + } + code = wrapper.TaosStmtExecute(stmt.stmt) + if code != 0 { + errStr := wrapper.TaosStmtErrStr(stmt.stmt) + return nil, errors.NewError(code, errStr) + } + res := wrapper.TaosStmtUseResult(stmt.stmt) + handler := asyncHandlerPool.Get() + numFields := wrapper.TaosNumFields(res) + rowsHeader, err := wrapper.ReadColumn(res, numFields) + if err != nil { + return nil, err + } + precision := wrapper.TaosResultPrecision(res) + rs := &rows{ + handler: handler, + rowsHeader: rowsHeader, + result: res, + precision: precision, + isStmt: true, + } + return rs, nil } func (stmt *Stmt) CheckNamedValue(v *driver.NamedValue) error { @@ -306,6 +305,26 @@ func (stmt *Stmt) CheckNamedValue(v *driver.NamedValue) error { default: return fmt.Errorf("CheckNamedValue:%v can not convert to binary", v) } + case common.TSDB_DATA_TYPE_VARBINARY: + switch v.Value.(type) { + case string: + v.Value = types.TaosVarBinary(v.Value.(string)) + case []byte: + v.Value = types.TaosVarBinary(v.Value.([]byte)) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to varbinary", v) + } + + case common.TSDB_DATA_TYPE_GEOMETRY: + switch v.Value.(type) { + case string: + v.Value = types.TaosGeometry(v.Value.(string)) + case []byte: + v.Value = types.TaosGeometry(v.Value.([]byte)) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to geometry", v) + } + case common.TSDB_DATA_TYPE_TIMESTAMP: t, is := v.Value.(time.Time) if is { @@ -481,7 +500,6 @@ func (stmt *Stmt) CheckNamedValue(v *driver.NamedValue) error { v.Value = types.TaosBinary(rv.Bytes()) } else { return fmt.Errorf("CheckNamedValue: can not convert query value %v", v) - } default: return fmt.Errorf("CheckNamedValue: can not convert query value %v", v) diff --git a/taosSql/statement_test.go b/taosSql/statement_test.go index 2442ff9..7230d78 100644 --- a/taosSql/statement_test.go +++ b/taosSql/statement_test.go @@ -11,6 +11,9 @@ import ( "github.com/stretchr/testify/assert" ) +// @author: xftan +// @date: 2023/10/13 11:22 +// @description: test stmt exec func TestStmtExec(t *testing.T) { db, err := sql.Open(driverName, dataSourceName) if err != nil { @@ -65,126 +68,129 @@ func TestStmtExec(t *testing.T) { assert.Equal(t, int64(1), affected) } -//func TestStmtQuery(t *testing.T) { -// db, err := sql.Open(driverName, dataSourceName) -// if err != nil { -// t.Error(err) -// return -// } -// defer db.Close() -// defer func() { -// db.Exec("drop database if exists test_stmt_driver") -// }() -// _, err = db.Exec("create database if not exists test_stmt_driver") -// if err != nil { -// t.Error(err) -// return -// } -// _, err = db.Exec("create table if not exists test_stmt_driver.ct(ts timestamp," + -// "c1 bool," + -// "c2 tinyint," + -// "c3 smallint," + -// "c4 int," + -// "c5 bigint," + -// "c6 tinyint unsigned," + -// "c7 smallint unsigned," + -// "c8 int unsigned," + -// "c9 bigint unsigned," + -// "c10 float," + -// "c11 double," + -// "c12 binary(20)," + -// "c13 nchar(20)" + -// ")") -// if err != nil { -// t.Error(err) -// return -// } -// stmt, err := db.Prepare("insert into test_stmt_driver.ct values (?,?,?,?,?,?,?,?,?,?,?,?,?,?)") -// if err != nil { -// t.Error(err) -// return -// } -// now := time.Now() -// result, err := stmt.Exec(now, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, "binary", "nchar") -// if err != nil { -// t.Error(err) -// return -// } -// affected, err := result.RowsAffected() -// if err != nil { -// t.Error(err) -// return -// } -// assert.Equal(t, int64(1), affected) -// stmt.Close() -// stmt, err = db.Prepare("select * from test_stmt_driver.ct where ts = ?") -// if err != nil { -// t.Error(err) -// return -// } -// rows, err := stmt.Query(now) -// if err != nil { -// t.Error(err) -// return -// } -// columns, err := rows.Columns() -// if err != nil { -// t.Error(err) -// return -// } -// assert.Equal(t, []string{"ts", "c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8", "c9", "c10", "c11", "c12", "c13"}, columns) -// count := 0 -// for rows.Next() { -// count += 1 -// var ( -// ts time.Time -// c1 bool -// c2 int8 -// c3 int16 -// c4 int32 -// c5 int64 -// c6 uint8 -// c7 uint16 -// c8 uint32 -// c9 uint64 -// c10 float32 -// c11 float64 -// c12 string -// c13 string -// ) -// err = rows.Scan(&ts, -// &c1, -// &c2, -// &c3, -// &c4, -// &c5, -// &c6, -// &c7, -// &c8, -// &c9, -// &c10, -// &c11, -// &c12, -// &c13) -// assert.NoError(t, err) -// assert.Equal(t, now.UnixNano()/1e6, ts.UnixNano()/1e6) -// assert.Equal(t, true, c1) -// assert.Equal(t, int8(2), c2) -// assert.Equal(t, int16(3), c3) -// assert.Equal(t, int32(4), c4) -// assert.Equal(t, int64(5), c5) -// assert.Equal(t, uint8(6), c6) -// assert.Equal(t, uint16(7), c7) -// assert.Equal(t, uint32(8), c8) -// assert.Equal(t, uint64(9), c9) -// assert.Equal(t, float32(10), c10) -// assert.Equal(t, float64(11), c11) -// assert.Equal(t, "binary", c12) -// assert.Equal(t, "nchar", c13) -// } -// assert.Equal(t, 1, count) -//} +func TestStmtQuery(t *testing.T) { + db, err := sql.Open(driverName, dataSourceName) + if err != nil { + t.Error(err) + return + } + defer db.Close() + defer func() { + db.Exec("drop database if exists test_stmt_driver_q") + }() + _, err = db.Exec("create database if not exists test_stmt_driver_q") + if err != nil { + t.Error(err) + return + } + _, err = db.Exec("create table if not exists test_stmt_driver_q.ct(ts timestamp," + + "c1 bool," + + "c2 tinyint," + + "c3 smallint," + + "c4 int," + + "c5 bigint," + + "c6 tinyint unsigned," + + "c7 smallint unsigned," + + "c8 int unsigned," + + "c9 bigint unsigned," + + "c10 float," + + "c11 double," + + "c12 binary(20)," + + "c13 nchar(20)" + + ")") + if err != nil { + t.Error(err) + return + } + stmt, err := db.Prepare("insert into test_stmt_driver_q.ct values (?,?,?,?,?,?,?,?,?,?,?,?,?,?)") + if err != nil { + t.Error(err) + return + } + now := time.Now() + result, err := stmt.Exec(now, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, "binary", "nchar") + if err != nil { + t.Error(err) + return + } + affected, err := result.RowsAffected() + if err != nil { + t.Error(err) + return + } + assert.Equal(t, int64(1), affected) + stmt.Close() + stmt, err = db.Prepare("select * from test_stmt_driver_q.ct where ts = ?") + if err != nil { + t.Error(err) + return + } + rows, err := stmt.Query(now) + if err != nil { + t.Error(err) + return + } + columns, err := rows.Columns() + if err != nil { + t.Error(err) + return + } + assert.Equal(t, []string{"ts", "c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8", "c9", "c10", "c11", "c12", "c13"}, columns) + count := 0 + for rows.Next() { + count += 1 + var ( + ts time.Time + c1 bool + c2 int8 + c3 int16 + c4 int32 + c5 int64 + c6 uint8 + c7 uint16 + c8 uint32 + c9 uint64 + c10 float32 + c11 float64 + c12 string + c13 string + ) + err = rows.Scan(&ts, + &c1, + &c2, + &c3, + &c4, + &c5, + &c6, + &c7, + &c8, + &c9, + &c10, + &c11, + &c12, + &c13) + assert.NoError(t, err) + assert.Equal(t, now.UnixNano()/1e6, ts.UnixNano()/1e6) + assert.Equal(t, true, c1) + assert.Equal(t, int8(2), c2) + assert.Equal(t, int16(3), c3) + assert.Equal(t, int32(4), c4) + assert.Equal(t, int64(5), c5) + assert.Equal(t, uint8(6), c6) + assert.Equal(t, uint16(7), c7) + assert.Equal(t, uint32(8), c8) + assert.Equal(t, uint64(9), c9) + assert.Equal(t, float32(10), c10) + assert.Equal(t, float64(11), c11) + assert.Equal(t, "binary", c12) + assert.Equal(t, "nchar", c13) + } + assert.Equal(t, 1, count) +} +// @author: xftan +// @date: 2023/10/13 11:22 +// @description: test stmt convert func TestStmtConvertExec(t *testing.T) { db, err := sql.Open(driverName, dataSourceName) if err != nil { @@ -1088,151 +1094,1066 @@ func TestStmtConvertExec(t *testing.T) { } } -//func TestStmtConvertQuery(t *testing.T) { -// db, err := sql.Open(driverName, dataSourceName) -// if err != nil { -// t.Error(err) -// return -// } -// defer db.Close() -// _, err = db.Exec("drop database if exists test_stmt_driver_convert_q") -// if err != nil { -// t.Error(err) -// return -// } -// defer func() { -// _, err = db.Exec("drop database if exists test_stmt_driver_convert_q") -// if err != nil { -// t.Error(err) -// return -// } -// }() -// _, err = db.Exec("create database test_stmt_driver_convert_q") -// if err != nil { -// t.Error(err) -// return -// } -// _, err = db.Exec("use test_stmt_driver_convert_q") -// if err != nil { -// t.Error(err) -// return -// } -// _, err = db.Exec("create table t0 (ts timestamp," + -// "c1 bool," + -// "c2 tinyint," + -// "c3 smallint," + -// "c4 int," + -// "c5 bigint," + -// "c6 tinyint unsigned," + -// "c7 smallint unsigned," + -// "c8 int unsigned," + -// "c9 bigint unsigned," + -// "c10 float," + -// "c11 double," + -// "c12 binary(20)," + -// "c13 nchar(20)" + -// ")") -// if err != nil { -// t.Error(err) -// return -// } -// now := time.Now() -// after1s := now.Add(time.Second) -// _, err = db.Exec(fmt.Sprintf("insert into t0 values('%s',true,2,3,4,5,6,7,8,9,10,11,'binary','nchar')", now.Format(time.RFC3339Nano))) -// if err != nil { -// t.Error(err) -// return -// } -// _, err = db.Exec(fmt.Sprintf("insert into t0 values('%s',null,null,null,null,null,null,null,null,null,null,null,null,null)", after1s.Format(time.RFC3339Nano))) -// if err != nil { -// t.Error(err) -// return -// } -// tests := []struct { -// name string -// field string -// where string -// bind interface{} -// expectNoValue bool -// expectValue driver.Value -// expectError bool -// }{ -// { -// name: "bool_true", -// field: "c1", -// where: "c1 = ?", -// bind: true, -// expectValue: true, -// }, -// { -// name: "bool_false", -// field: "c1", -// where: "c1 = ?", -// bind: false, -// expectNoValue: true, -// }, -// } -// for _, tt := range tests { -// t.Run(tt.name, func(t *testing.T) { -// sql := fmt.Sprintf("select %s from t0 where %s", tt.field, tt.where) -// -// stmt, err := db.Prepare(sql) -// if err != nil { -// t.Error(err) -// return -// } -// rows, err := stmt.Query(tt.bind) -// if tt.expectError { -// assert.NotNil(t, err) -// stmt.Close() -// return -// } -// if err != nil { -// t.Error(err) -// return -// } -// tts, err := rows.ColumnTypes() -// typesL := make([]reflect.Type, 1) -// for i, tp := range tts { -// st := tp.ScanType() -// if st == nil { -// t.Errorf("scantype is null for column %q", tp.Name()) -// continue -// } -// typesL[i] = st -// } -// var data []driver.Value -// for rows.Next() { -// values := make([]interface{}, 1) -// for i := range values { -// values[i] = reflect.New(typesL[i]).Interface() -// } -// err = rows.Scan(values...) -// if err != nil { -// t.Error(err) -// return -// } -// v, err := values[0].(driver.Valuer).Value() -// if err != nil { -// t.Error(err) -// } -// data = append(data, v) -// } -// if tt.expectNoValue { -// if len(data) > 0 { -// t.Errorf("expect no value got %#v", data) -// return -// } -// return -// } -// if len(data) != 1 { -// t.Errorf("expect %d got %d", 1, len(data)) -// return -// } -// if data[0] != tt.expectValue { -// t.Errorf("expect %v got %v", tt.expectValue, data[0]) -// return -// } -// }) -// } -//} +func TestStmtConvertQuery(t *testing.T) { + db, err := sql.Open(driverName, dataSourceName) + if err != nil { + t.Error(err) + return + } + defer db.Close() + _, err = db.Exec("drop database if exists test_stmt_driver_convert_q") + if err != nil { + t.Error(err) + return + } + defer func() { + _, err = db.Exec("drop database if exists test_stmt_driver_convert_q") + if err != nil { + t.Error(err) + return + } + }() + _, err = db.Exec("create database test_stmt_driver_convert_q") + if err != nil { + t.Error(err) + return + } + _, err = db.Exec("use test_stmt_driver_convert_q") + if err != nil { + t.Error(err) + return + } + _, err = db.Exec("create table t0 (ts timestamp," + + "c1 bool," + + "c2 tinyint," + + "c3 smallint," + + "c4 int," + + "c5 bigint," + + "c6 tinyint unsigned," + + "c7 smallint unsigned," + + "c8 int unsigned," + + "c9 bigint unsigned," + + "c10 float," + + "c11 double," + + "c12 binary(20)," + + "c13 nchar(20)" + + ")") + if err != nil { + t.Error(err) + return + } + now := time.Now() + after1s := now.Add(time.Second) + _, err = db.Exec(fmt.Sprintf("insert into t0 values('%s',true,2,3,4,5,6,7,8,9,10,11,'binary','nchar')", now.Format(time.RFC3339Nano))) + if err != nil { + t.Error(err) + return + } + _, err = db.Exec(fmt.Sprintf("insert into t0 values('%s',null,null,null,null,null,null,null,null,null,null,null,null,null)", after1s.Format(time.RFC3339Nano))) + if err != nil { + t.Error(err) + return + } + tests := []struct { + name string + field string + where string + bind interface{} + expectNoValue bool + expectValue driver.Value + expectError bool + }{ + //ts + { + name: "ts", + field: "ts", + where: "ts = ?", + bind: now, + expectValue: time.Unix(now.Unix(), int64((now.Nanosecond()/1e6)*1e6)).Local(), + }, + + //bool + { + name: "bool_true", + field: "c1", + where: "c1 = ?", + bind: true, + expectValue: true, + }, + { + name: "bool_false", + field: "c1", + where: "c1 = ?", + bind: false, + expectNoValue: true, + }, + { + name: "tinyint_int8", + field: "c2", + where: "c2 = ?", + bind: int8(2), + expectValue: int8(2), + }, + { + name: "tinyint_iny16", + field: "c2", + where: "c2 = ?", + bind: int16(2), + expectValue: int8(2), + }, + { + name: "tinyint_int32", + field: "c2", + where: "c2 = ?", + bind: int32(2), + expectValue: int8(2), + }, + { + name: "tinyint_int64", + field: "c2", + where: "c2 = ?", + bind: int64(2), + expectValue: int8(2), + }, + { + name: "tinyint_uint8", + field: "c2", + where: "c2 = ?", + bind: uint8(2), + expectValue: int8(2), + }, + { + name: "tinyint_uint16", + field: "c2", + where: "c2 = ?", + bind: uint16(2), + expectValue: int8(2), + }, + { + name: "tinyint_uint32", + field: "c2", + where: "c2 = ?", + bind: uint32(2), + expectValue: int8(2), + }, + { + name: "tinyint_uint64", + field: "c2", + where: "c2 = ?", + bind: uint64(2), + expectValue: int8(2), + }, + { + name: "tinyint_float32", + field: "c2", + where: "c2 = ?", + bind: float32(2), + expectValue: int8(2), + }, + { + name: "tinyint_float64", + field: "c2", + where: "c2 = ?", + bind: float64(2), + expectValue: int8(2), + }, + { + name: "tinyint_int", + field: "c2", + where: "c2 = ?", + bind: int(2), + expectValue: int8(2), + }, + { + name: "tinyint_uint", + field: "c2", + where: "c2 = ?", + bind: uint(2), + expectValue: int8(2), + }, + + // smallint + { + name: "smallint_int8", + field: "c3", + where: "c3 = ?", + bind: int8(3), + expectValue: int16(3), + }, + { + name: "smallint_iny16", + field: "c3", + where: "c3 = ?", + bind: int16(3), + expectValue: int16(3), + }, + { + name: "smallint_int32", + field: "c3", + where: "c3 = ?", + bind: int32(3), + expectValue: int16(3), + }, + { + name: "smallint_int64", + field: "c3", + where: "c3 = ?", + bind: int64(3), + expectValue: int16(3), + }, + { + name: "smallint_uint8", + field: "c3", + where: "c3 = ?", + bind: uint8(3), + expectValue: int16(3), + }, + { + name: "smallint_uint16", + field: "c3", + where: "c3 = ?", + bind: uint16(3), + expectValue: int16(3), + }, + { + name: "smallint_uint32", + field: "c3", + where: "c3 = ?", + bind: uint32(3), + expectValue: int16(3), + }, + { + name: "smallint_uint64", + field: "c3", + where: "c3 = ?", + bind: uint64(3), + expectValue: int16(3), + }, + { + name: "smallint_float32", + field: "c3", + where: "c3 = ?", + bind: float32(3), + expectValue: int16(3), + }, + { + name: "smallint_float64", + field: "c3", + where: "c3 = ?", + bind: float64(3), + expectValue: int16(3), + }, + { + name: "smallint_int", + field: "c3", + where: "c3 = ?", + bind: int(3), + expectValue: int16(3), + }, + { + name: "smallint_uint", + field: "c3", + where: "c3 = ?", + bind: uint(3), + expectValue: int16(3), + }, + + //int + { + name: "int_int8", + field: "c4", + where: "c4 = ?", + bind: int8(4), + expectValue: int32(4), + }, + { + name: "int_iny16", + field: "c4", + where: "c4 = ?", + bind: int16(4), + expectValue: int32(4), + }, + { + name: "int_int32", + field: "c4", + where: "c4 = ?", + bind: int32(4), + expectValue: int32(4), + }, + { + name: "int_int64", + field: "c4", + where: "c4 = ?", + bind: int64(4), + expectValue: int32(4), + }, + { + name: "int_uint8", + field: "c4", + where: "c4 = ?", + bind: uint8(4), + expectValue: int32(4), + }, + { + name: "int_uint16", + field: "c4", + where: "c4 = ?", + bind: uint16(4), + expectValue: int32(4), + }, + { + name: "int_uint32", + field: "c4", + where: "c4 = ?", + bind: uint32(4), + expectValue: int32(4), + }, + { + name: "int_uint64", + field: "c4", + where: "c4 = ?", + bind: uint64(4), + expectValue: int32(4), + }, + { + name: "int_float32", + field: "c4", + where: "c4 = ?", + bind: float32(4), + expectValue: int32(4), + }, + { + name: "int_float64", + field: "c4", + where: "c4 = ?", + bind: float64(4), + expectValue: int32(4), + }, + { + name: "int_int", + field: "c4", + where: "c4 = ?", + bind: int(4), + expectValue: int32(4), + }, + { + name: "int_uint", + field: "c4", + where: "c4 = ?", + bind: uint(4), + expectValue: int32(4), + }, + + //bigint + { + name: "bigint_int8", + field: "c5", + where: "c5 = ?", + bind: int8(5), + expectValue: int64(5), + }, + { + name: "bigint_iny16", + field: "c5", + where: "c5 = ?", + bind: int16(5), + expectValue: int64(5), + }, + { + name: "bigint_int32", + field: "c5", + where: "c5 = ?", + bind: int32(5), + expectValue: int64(5), + }, + { + name: "bigint_int64", + field: "c5", + where: "c5 = ?", + bind: int64(5), + expectValue: int64(5), + }, + { + name: "bigint_uint8", + field: "c5", + where: "c5 = ?", + bind: uint8(5), + expectValue: int64(5), + }, + { + name: "bigint_uint16", + field: "c5", + where: "c5 = ?", + bind: uint16(5), + expectValue: int64(5), + }, + { + name: "bigint_uint32", + field: "c5", + where: "c5 = ?", + bind: uint32(5), + expectValue: int64(5), + }, + { + name: "bigint_uint64", + field: "c5", + where: "c5 = ?", + bind: uint64(5), + expectValue: int64(5), + }, + { + name: "bigint_float32", + field: "c5", + where: "c5 = ?", + bind: float32(5), + expectValue: int64(5), + }, + { + name: "bigint_float64", + field: "c5", + where: "c5 = ?", + bind: float64(5), + expectValue: int64(5), + }, + { + name: "bigint_int", + field: "c5", + where: "c5 = ?", + bind: int(5), + expectValue: int64(5), + }, + { + name: "bigint_uint", + field: "c5", + where: "c5 = ?", + bind: uint(5), + expectValue: int64(5), + }, + + //utinyint + { + name: "utinyint_int8", + field: "c6", + where: "c6 = ?", + bind: int8(6), + expectValue: uint8(6), + }, + { + name: "utinyint_iny16", + field: "c6", + where: "c6 = ?", + bind: int16(6), + expectValue: uint8(6), + }, + { + name: "utinyint_int32", + field: "c6", + where: "c6 = ?", + bind: int32(6), + expectValue: uint8(6), + }, + { + name: "utinyint_int64", + field: "c6", + where: "c6 = ?", + bind: int64(6), + expectValue: uint8(6), + }, + { + name: "utinyint_uint8", + field: "c6", + where: "c6 = ?", + bind: uint8(6), + expectValue: uint8(6), + }, + { + name: "utinyint_uint16", + field: "c6", + where: "c6 = ?", + bind: uint16(6), + expectValue: uint8(6), + }, + { + name: "utinyint_uint32", + field: "c6", + where: "c6 = ?", + bind: uint32(6), + expectValue: uint8(6), + }, + { + name: "utinyint_uint64", + field: "c6", + where: "c6 = ?", + bind: uint64(6), + expectValue: uint8(6), + }, + { + name: "utinyint_float32", + field: "c6", + where: "c6 = ?", + bind: float32(6), + expectValue: uint8(6), + }, + { + name: "utinyint_float64", + field: "c6", + where: "c6 = ?", + bind: float64(6), + expectValue: uint8(6), + }, + { + name: "utinyint_int", + field: "c6", + where: "c6 = ?", + bind: int(6), + expectValue: uint8(6), + }, + { + name: "utinyint_uint", + field: "c6", + where: "c6 = ?", + bind: uint(6), + expectValue: uint8(6), + }, + + //usmallint + { + name: "usmallint_int8", + field: "c7", + where: "c7 = ?", + bind: int8(7), + expectValue: uint16(7), + }, + { + name: "usmallint_iny16", + field: "c7", + where: "c7 = ?", + bind: int16(7), + expectValue: uint16(7), + }, + { + name: "usmallint_int32", + field: "c7", + where: "c7 = ?", + bind: int32(7), + expectValue: uint16(7), + }, + { + name: "usmallint_int64", + field: "c7", + where: "c7 = ?", + bind: int64(7), + expectValue: uint16(7), + }, + { + name: "usmallint_uint8", + field: "c7", + where: "c7 = ?", + bind: uint8(7), + expectValue: uint16(7), + }, + { + name: "usmallint_uint16", + field: "c7", + where: "c7 = ?", + bind: uint16(7), + expectValue: uint16(7), + }, + { + name: "usmallint_uint32", + field: "c7", + where: "c7 = ?", + bind: uint32(7), + expectValue: uint16(7), + }, + { + name: "usmallint_uint64", + field: "c7", + where: "c7 = ?", + bind: uint64(7), + expectValue: uint16(7), + }, + { + name: "usmallint_float32", + field: "c7", + where: "c7 = ?", + bind: float32(7), + expectValue: uint16(7), + }, + { + name: "usmallint_float64", + field: "c7", + where: "c7 = ?", + bind: float64(7), + expectValue: uint16(7), + }, + { + name: "usmallint_int", + field: "c7", + where: "c7 = ?", + bind: int(7), + expectValue: uint16(7), + }, + { + name: "usmallint_uint", + field: "c7", + where: "c7 = ?", + bind: uint(7), + expectValue: uint16(7), + }, + + //uint + { + name: "uint_int8", + field: "c8", + where: "c8 = ?", + bind: int8(8), + expectValue: uint32(8), + }, + { + name: "uint_iny16", + field: "c8", + where: "c8 = ?", + bind: int16(8), + expectValue: uint32(8), + }, + { + name: "uint_int32", + field: "c8", + where: "c8 = ?", + bind: int32(8), + expectValue: uint32(8), + }, + { + name: "uint_int64", + field: "c8", + where: "c8 = ?", + bind: int64(8), + expectValue: uint32(8), + }, + { + name: "uint_uint8", + field: "c8", + where: "c8 = ?", + bind: uint8(8), + expectValue: uint32(8), + }, + { + name: "uint_uint16", + field: "c8", + where: "c8 = ?", + bind: uint16(8), + expectValue: uint32(8), + }, + { + name: "uint_uint32", + field: "c8", + where: "c8 = ?", + bind: uint32(8), + expectValue: uint32(8), + }, + { + name: "uint_uint64", + field: "c8", + where: "c8 = ?", + bind: uint64(8), + expectValue: uint32(8), + }, + { + name: "uint_float32", + field: "c8", + where: "c8 = ?", + bind: float32(8), + expectValue: uint32(8), + }, + { + name: "uint_float64", + field: "c8", + where: "c8 = ?", + bind: float64(8), + expectValue: uint32(8), + }, + { + name: "uint_int", + field: "c8", + where: "c8 = ?", + bind: int(8), + expectValue: uint32(8), + }, + { + name: "uint_uint", + field: "c8", + where: "c8 = ?", + bind: uint(8), + expectValue: uint32(8), + }, + + //ubigint + { + name: "ubigint_int8", + field: "c9", + where: "c9 = ?", + bind: int8(9), + expectValue: uint64(9), + }, + { + name: "ubigint_iny16", + field: "c9", + where: "c9 = ?", + bind: int16(9), + expectValue: uint64(9), + }, + { + name: "ubigint_int32", + field: "c9", + where: "c9 = ?", + bind: int32(9), + expectValue: uint64(9), + }, + { + name: "ubigint_int64", + field: "c9", + where: "c9 = ?", + bind: int64(9), + expectValue: uint64(9), + }, + { + name: "ubigint_uint8", + field: "c9", + where: "c9 = ?", + bind: uint8(9), + expectValue: uint64(9), + }, + { + name: "ubigint_uint16", + field: "c9", + where: "c9 = ?", + bind: uint16(9), + expectValue: uint64(9), + }, + { + name: "ubigint_uint32", + field: "c9", + where: "c9 = ?", + bind: uint32(9), + expectValue: uint64(9), + }, + { + name: "ubigint_uint64", + field: "c9", + where: "c9 = ?", + bind: uint64(9), + expectValue: uint64(9), + }, + { + name: "ubigint_float32", + field: "c9", + where: "c9 = ?", + bind: float32(9), + expectValue: uint64(9), + }, + { + name: "ubigint_float64", + field: "c9", + where: "c9 = ?", + bind: float64(9), + expectValue: uint64(9), + }, + { + name: "ubigint_int", + field: "c9", + where: "c9 = ?", + bind: int(9), + expectValue: uint64(9), + }, + { + name: "ubigint_uint", + field: "c9", + where: "c9 = ?", + bind: uint(9), + expectValue: uint64(9), + }, + + //float + { + name: "float_int8", + field: "c10", + where: "c10 = ?", + bind: int8(10), + expectValue: float32(10), + }, + { + name: "float_iny16", + field: "c10", + where: "c10 = ?", + bind: int16(10), + expectValue: float32(10), + }, + { + name: "float_int32", + field: "c10", + where: "c10 = ?", + bind: int32(10), + expectValue: float32(10), + }, + { + name: "float_int64", + field: "c10", + where: "c10 = ?", + bind: int64(10), + expectValue: float32(10), + }, + { + name: "float_uint8", + field: "c10", + where: "c10 = ?", + bind: uint8(10), + expectValue: float32(10), + }, + { + name: "float_uint16", + field: "c10", + where: "c10 = ?", + bind: uint16(10), + expectValue: float32(10), + }, + { + name: "float_uint32", + field: "c10", + where: "c10 = ?", + bind: uint32(10), + expectValue: float32(10), + }, + { + name: "float_uint64", + field: "c10", + where: "c10 = ?", + bind: uint64(10), + expectValue: float32(10), + }, + { + name: "float_float32", + field: "c10", + where: "c10 = ?", + bind: float32(10), + expectValue: float32(10), + }, + { + name: "float_float64", + field: "c10", + where: "c10 = ?", + bind: float64(10), + expectValue: float32(10), + }, + { + name: "float_int", + field: "c10", + where: "c10 = ?", + bind: int(10), + expectValue: float32(10), + }, + { + name: "float_uint", + field: "c10", + where: "c10 = ?", + bind: uint(10), + expectValue: float32(10), + }, + + //double + { + name: "double_int8", + field: "c11", + where: "c11 = ?", + bind: int8(11), + expectValue: float64(11), + }, + { + name: "double_iny16", + field: "c11", + where: "c11 = ?", + bind: int16(11), + expectValue: float64(11), + }, + { + name: "double_int32", + field: "c11", + where: "c11 = ?", + bind: int32(11), + expectValue: float64(11), + }, + { + name: "double_int64", + field: "c11", + where: "c11 = ?", + bind: int64(11), + expectValue: float64(11), + }, + { + name: "double_uint8", + field: "c11", + where: "c11 = ?", + bind: uint8(11), + expectValue: float64(11), + }, + { + name: "double_uint16", + field: "c11", + where: "c11 = ?", + bind: uint16(11), + expectValue: float64(11), + }, + { + name: "double_uint32", + field: "c11", + where: "c11 = ?", + bind: uint32(11), + expectValue: float64(11), + }, + { + name: "double_uint64", + field: "c11", + where: "c11 = ?", + bind: uint64(11), + expectValue: float64(11), + }, + { + name: "double_float32", + field: "c11", + where: "c11 = ?", + bind: float32(11), + expectValue: float64(11), + }, + { + name: "double_float64", + field: "c11", + where: "c11 = ?", + bind: float64(11), + expectValue: float64(11), + }, + { + name: "double_int", + field: "c11", + where: "c11 = ?", + bind: int(11), + expectValue: float64(11), + }, + { + name: "double_uint", + field: "c11", + where: "c11 = ?", + bind: uint(11), + expectValue: float64(11), + }, + + // binary + { + name: "binary_string", + field: "c12", + where: "c12 = ?", + bind: "binary", + expectValue: "binary", + }, + { + name: "binary_bytes", + field: "c12", + where: "c12 = ?", + bind: []byte("binary"), + expectValue: "binary", + }, + { + name: "binary_string_like", + field: "c12", + where: "c12 like ?", + bind: "bin%", + expectValue: "binary", + }, + + // nchar + { + name: "nchar_string", + field: "c13", + where: "c13 = ?", + bind: "nchar", + expectValue: "nchar", + }, + { + name: "nchar_bytes", + field: "c13", + where: "c13 = ?", + bind: []byte("nchar"), + expectValue: "nchar", + }, + { + name: "nchar_string", + field: "c13", + where: "c13 like ?", + bind: "nch%", + expectValue: "nchar", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sql := fmt.Sprintf("select %s from t0 where %s", tt.field, tt.where) + + stmt, err := db.Prepare(sql) + if err != nil { + t.Error(err) + return + } + defer stmt.Close() + rows, err := stmt.Query(tt.bind) + if tt.expectError { + assert.NotNil(t, err) + stmt.Close() + return + } + if err != nil { + t.Error(err) + return + } + tts, err := rows.ColumnTypes() + typesL := make([]reflect.Type, 1) + for i, tp := range tts { + st := tp.ScanType() + if st == nil { + t.Errorf("scantype is null for column %q", tp.Name()) + continue + } + typesL[i] = st + } + var data []driver.Value + for rows.Next() { + values := make([]interface{}, 1) + for i := range values { + values[i] = reflect.New(typesL[i]).Interface() + } + err = rows.Scan(values...) + if err != nil { + t.Error(err) + return + } + v, err := values[0].(driver.Valuer).Value() + if err != nil { + t.Error(err) + } + data = append(data, v) + } + if tt.expectNoValue { + if len(data) > 0 { + t.Errorf("expect no value got %#v", data) + return + } + return + } + if len(data) != 1 { + t.Errorf("expect %d got %d", 1, len(data)) + return + } + if data[0] != tt.expectValue { + t.Errorf("expect %v got %v", tt.expectValue, data[0]) + return + } + }) + } +} diff --git a/taosWS/connection_test.go b/taosWS/connection_test.go index 200aefe..6aee030 100644 --- a/taosWS/connection_test.go +++ b/taosWS/connection_test.go @@ -6,6 +6,9 @@ import ( "github.com/stretchr/testify/assert" ) +// @author: xftan +// @date: 2023/10/13 11:22 +// @description: test format bytes func Test_formatBytes(t *testing.T) { type args struct { bs []byte diff --git a/taosWS/connector_test.go b/taosWS/connector_test.go index 9ea2176..a40a251 100644 --- a/taosWS/connector_test.go +++ b/taosWS/connector_test.go @@ -11,6 +11,9 @@ import ( "github.com/taosdata/driver-go/v3/types" ) +// @author: xftan +// @date: 2023/10/13 11:22 +// @description: test all type query func TestAllTypeQuery(t *testing.T) { rand.Seed(time.Now().UnixNano()) db, err := sql.Open("taosWS", dataSourceName) @@ -141,6 +144,9 @@ func TestAllTypeQuery(t *testing.T) { } } +// @author: xftan +// @date: 2023/10/13 11:22 +// @description: test null value func TestAllTypeQueryNull(t *testing.T) { rand.Seed(time.Now().UnixNano()) db, err := sql.Open("taosWS", dataSourceName) @@ -257,6 +263,9 @@ func TestAllTypeQueryNull(t *testing.T) { } } +// @author: xftan +// @date: 2023/10/13 11:24 +// @description: test compression func TestAllTypeQueryCompression(t *testing.T) { rand.Seed(time.Now().UnixNano()) db, err := sql.Open("taosWS", dataSourceNameWithCompression) @@ -386,6 +395,9 @@ func TestAllTypeQueryCompression(t *testing.T) { } } +// @author: xftan +// @date: 2023/10/13 11:24 +// @description: test all type query without json func TestAllTypeQueryWithoutJson(t *testing.T) { rand.Seed(time.Now().UnixNano()) db, err := sql.Open("taosWS", dataSourceName) @@ -512,6 +524,9 @@ func TestAllTypeQueryWithoutJson(t *testing.T) { } } +// @author: xftan +// @date: 2023/10/13 11:24 +// @description: test all type query with null without json func TestAllTypeQueryNullWithoutJson(t *testing.T) { rand.Seed(time.Now().UnixNano()) db, err := sql.Open("taosWS", dataSourceName) @@ -624,6 +639,9 @@ func TestAllTypeQueryNullWithoutJson(t *testing.T) { } } +// @author: xftan +// @date: 2023/10/13 11:24 +// @description: test query func TestBatch(t *testing.T) { now := time.Now() tests := []struct { diff --git a/taosWS/driver_test.go b/taosWS/driver_test.go index da3543d..9e12d7a 100644 --- a/taosWS/driver_test.go +++ b/taosWS/driver_test.go @@ -13,6 +13,7 @@ import ( ) // Ensure that all the driver interfaces are implemented + func TestMain(m *testing.M) { m.Run() db, err := sql.Open(driverName, dataSourceName) @@ -112,6 +113,9 @@ func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *sql.Rows, return } +// @author: xftan +// @date: 2023/10/13 11:25 +// @description: test empty query func TestEmptyQuery(t *testing.T) { runTests(t, func(dbt *DBTest) { // just a comment, no query @@ -123,6 +127,9 @@ func TestEmptyQuery(t *testing.T) { }) } +// @author: xftan +// @date: 2023/10/13 11:25 +// @description: test error query func TestErrorQuery(t *testing.T) { runTests(t, func(dbt *DBTest) { // just a comment, no query @@ -191,6 +198,9 @@ var ( } ) +// @author: xftan +// @date: 2023/10/13 11:25 +// @description: test select and query func TestAny(t *testing.T) { runTests(t, func(dbt *DBTest) { now := time.Now() @@ -215,6 +225,9 @@ func TestAny(t *testing.T) { }) } +// @author: xftan +// @date: 2023/10/13 11:26 +// @description: test chinese func TestChinese(t *testing.T) { db, err := sql.Open(driverName, dataSourceName) if err != nil { @@ -223,33 +236,33 @@ func TestChinese(t *testing.T) { } defer db.Close() defer func() { - _, err = db.Exec("drop database if exists test_chinese") + _, err = db.Exec("drop database if exists test_chinese_ws") if err != nil { t.Error(err) return } }() - _, err = db.Exec("create database if not exists test_chinese") + _, err = db.Exec("create database if not exists test_chinese_ws") if err != nil { t.Error(err) return } - _, err = db.Exec("drop table if exists test_chinese.chinese") + _, err = db.Exec("drop table if exists test_chinese_ws.chinese") if err != nil { t.Error(err) return } - _, err = db.Exec("create table if not exists test_chinese.chinese(ts timestamp,v nchar(32))") + _, err = db.Exec("create table if not exists test_chinese_ws.chinese(ts timestamp,v nchar(32))") if err != nil { t.Error(err) return } - _, err = db.Exec(`INSERT INTO test_chinese.chinese (ts, v) VALUES (?, ?)`, "1641010332000", "'阴天'") + _, err = db.Exec(`INSERT INTO test_chinese_ws.chinese (ts, v) VALUES (?, ?)`, "1641010332000", "'阴天'") if err != nil { t.Error(err) return } - rows, err := db.Query("select * from test_chinese.chinese") + rows, err := db.Query("select * from test_chinese_ws.chinese") if err != nil { t.Error(err) return diff --git a/taosWS/dsn_test.go b/taosWS/dsn_test.go index 356be70..edfd013 100644 --- a/taosWS/dsn_test.go +++ b/taosWS/dsn_test.go @@ -7,6 +7,9 @@ import ( "github.com/stretchr/testify/assert" ) +// @author: xftan +// @date: 2023/10/13 11:26 +// @description: test parse dsn func TestParseDsn(t *testing.T) { tests := []struct { dsn string diff --git a/taosWS/error_test.go b/taosWS/error_test.go index c364f6d..5e165f2 100644 --- a/taosWS/error_test.go +++ b/taosWS/error_test.go @@ -8,6 +8,9 @@ import ( "github.com/stretchr/testify/assert" ) +// @author: xftan +// @date: 2023/10/13 11:26 +// @description: test bad conn error func TestBadConnError(t *testing.T) { nothingErr := errors.New("error") err := NewBadConnError(nothingErr) diff --git a/taosWS/rows.go b/taosWS/rows.go index b75f4e2..b462f4e 100644 --- a/taosWS/rows.go +++ b/taosWS/rows.go @@ -10,6 +10,7 @@ import ( "github.com/taosdata/driver-go/v3/common" "github.com/taosdata/driver-go/v3/common/parser" + "github.com/taosdata/driver-go/v3/common/pointer" taosErrors "github.com/taosdata/driver-go/v3/errors" ) @@ -151,7 +152,7 @@ func (rs *rows) fetchBlock() error { return err } rs.block = respBytes - rs.blockPtr = unsafe.Pointer(*(*uintptr)(unsafe.Pointer(&rs.block)) + uintptr(16)) + rs.blockPtr = pointer.AddUintptr(unsafe.Pointer(&rs.block[0]), 16) rs.blockOffset = 0 return nil } diff --git a/types/taostype.go b/types/taostype.go index 50a5da8..f1bbcc2 100644 --- a/types/taostype.go +++ b/types/taostype.go @@ -18,12 +18,14 @@ type ( TaosFloat float32 TaosDouble float64 TaosBinary []byte + TaosVarBinary []byte TaosNchar string TaosTimestamp struct { T time.Time Precision int } - TaosJson []byte + TaosJson []byte + TaosGeometry []byte ) var ( @@ -39,9 +41,11 @@ var ( TaosFloatType = reflect.TypeOf(TaosFloat(0)) TaosDoubleType = reflect.TypeOf(TaosDouble(0)) TaosBinaryType = reflect.TypeOf(TaosBinary(nil)) + TaosVarBinaryType = reflect.TypeOf(TaosVarBinary(nil)) TaosNcharType = reflect.TypeOf(TaosNchar("")) TaosTimestampType = reflect.TypeOf(TaosTimestamp{}) TaosJsonType = reflect.TypeOf(TaosJson("")) + TaosGeometryType = reflect.TypeOf(TaosGeometry(nil)) ) type ColumnType struct { diff --git a/wrapper/block.go b/wrapper/block.go index 30ef5a6..eacf894 100644 --- a/wrapper/block.go +++ b/wrapper/block.go @@ -33,3 +33,17 @@ func TaosWriteRawBlockWithFields(conn unsafe.Pointer, numOfRows int, pData unsaf defer C.free(unsafe.Pointer(cStr)) return int(C.taos_write_raw_block_with_fields(conn, (C.int)(numOfRows), (*C.char)(pData), cStr, (*C.struct_taosField)(fields), (C.int)(numFields))) } + +// DLL_EXPORT int taos_write_raw_block_with_reqid(TAOS *taos, int numOfRows, char *pData, const char *tbname, int64_t reqid); +func TaosWriteRawBlockWithReqID(conn unsafe.Pointer, numOfRows int, pData unsafe.Pointer, tableName string, reqID int64) int { + cStr := C.CString(tableName) + defer C.free(unsafe.Pointer(cStr)) + return int(C.taos_write_raw_block_with_reqid(conn, (C.int)(numOfRows), (*C.char)(pData), cStr, (C.int64_t)(reqID))) +} + +// DLL_EXPORT int taos_write_raw_block_with_fields_with_reqid(TAOS *taos, int rows, char *pData, const char *tbname,TAOS_FIELD *fields, int numFields, int64_t reqid); +func TaosWriteRawBlockWithFieldsWithReqID(conn unsafe.Pointer, numOfRows int, pData unsafe.Pointer, tableName string, fields unsafe.Pointer, numFields int, reqID int64) int { + cStr := C.CString(tableName) + defer C.free(unsafe.Pointer(cStr)) + return int(C.taos_write_raw_block_with_fields_with_reqid(conn, (C.int)(numOfRows), (*C.char)(pData), cStr, (*C.struct_taosField)(fields), (C.int)(numFields), (C.int64_t)(reqID))) +} diff --git a/wrapper/block_test.go b/wrapper/block_test.go index e3b7b63..22f6239 100644 --- a/wrapper/block_test.go +++ b/wrapper/block_test.go @@ -12,6 +12,9 @@ import ( "github.com/taosdata/driver-go/v3/errors" ) +// @author: xftan +// @date: 2023/10/13 11:27 +// @description: test read block func TestReadBlock(t *testing.T) { conn, err := TaosConnect("", "root", "taosdata", "", 0) if err != nil { @@ -178,6 +181,9 @@ func TestReadBlock(t *testing.T) { assert.Equal(t, []byte(`{"a":1}`), row3[14].([]byte)) } +// @author: xftan +// @date: 2023/10/13 11:27 +// @description: test write raw block func TestTaosWriteRawBlock(t *testing.T) { conn, err := TaosConnect("", "root", "taosdata", "", 0) if err != nil { @@ -364,6 +370,9 @@ func TestTaosWriteRawBlock(t *testing.T) { } } +// @author: xftan +// @date: 2023/10/13 11:28 +// @description: test write raw block with fields func TestTaosWriteRawBlockWithFields(t *testing.T) { conn, err := TaosConnect("", "root", "taosdata", "", 0) if err != nil { @@ -381,17 +390,17 @@ func TestTaosWriteRawBlockWithFields(t *testing.T) { return } TaosFreeResult(res) - //defer func() { - // res = TaosQuery(conn, "drop database if exists test_write_block_raw_fields") - // code = TaosError(res) - // if code != 0 { - // errStr := TaosErrorStr(res) - // TaosFreeResult(res) - // t.Error(errors.NewError(code, errStr)) - // return - // } - // TaosFreeResult(res) - //}() + defer func() { + res = TaosQuery(conn, "drop database if exists test_write_block_raw_fields") + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + }() res = TaosQuery(conn, "create database test_write_block_raw_fields") code = TaosError(res) if code != 0 { @@ -542,3 +551,374 @@ func TestTaosWriteRawBlockWithFields(t *testing.T) { assert.Nil(t, row2[i]) } } + +// @author: xftan +// @date: 2023/11/17 9:39 +// @description: test write raw block with reqid +func TestTaosWriteRawBlockWithReqID(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + + defer TaosClose(conn) + res := TaosQuery(conn, "drop database if exists test_write_block_raw_with_reqid") + code := TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + defer func() { + res = TaosQuery(conn, "drop database if exists test_write_block_raw_with_reqid") + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + }() + res = TaosQuery(conn, "create database test_write_block_raw_with_reqid") + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + + res = TaosQuery(conn, "create table if not exists test_write_block_raw_with_reqid.all_type (ts timestamp,"+ + "c1 bool,"+ + "c2 tinyint,"+ + "c3 smallint,"+ + "c4 int,"+ + "c5 bigint,"+ + "c6 tinyint unsigned,"+ + "c7 smallint unsigned,"+ + "c8 int unsigned,"+ + "c9 bigint unsigned,"+ + "c10 float,"+ + "c11 double,"+ + "c12 binary(20),"+ + "c13 nchar(20)"+ + ") tags (info json)") + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + now := time.Now() + after1s := now.Add(time.Second) + sql := fmt.Sprintf("insert into test_write_block_raw_with_reqid.t0 using test_write_block_raw_with_reqid.all_type tags('{\"a\":1}') values('%s',1,1,1,1,1,1,1,1,1,1,1,'test_binary','test_nchar')('%s',null,null,null,null,null,null,null,null,null,null,null,null,null)", now.Format(time.RFC3339Nano), after1s.Format(time.RFC3339Nano)) + res = TaosQuery(conn, sql) + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + + sql = "create table test_write_block_raw_with_reqid.t1 using test_write_block_raw_with_reqid.all_type tags('{\"a\":2}')" + res = TaosQuery(conn, sql) + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + + sql = "use test_write_block_raw_with_reqid" + res = TaosQuery(conn, sql) + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + + sql = "select * from test_write_block_raw_with_reqid.t0" + res = TaosQuery(conn, sql) + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + for { + blockSize, errCode, block := TaosFetchRawBlock(res) + if errCode != int(errors.SUCCESS) { + errStr := TaosErrorStr(res) + err := errors.NewError(errCode, errStr) + t.Error(err) + TaosFreeResult(res) + return + } + if blockSize == 0 { + break + } + + errCode = TaosWriteRawBlockWithReqID(conn, blockSize, block, "t1", 1) + if errCode != int(errors.SUCCESS) { + errStr := TaosErrorStr(nil) + err := errors.NewError(errCode, errStr) + t.Error(err) + TaosFreeResult(res) + return + } + } + TaosFreeResult(res) + + sql = "select * from test_write_block_raw_with_reqid.t1" + res = TaosQuery(conn, sql) + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + fileCount := TaosNumFields(res) + rh, err := ReadColumn(res, fileCount) + if err != nil { + t.Error(err) + return + } + precision := TaosResultPrecision(res) + var data [][]driver.Value + for { + blockSize, errCode, block := TaosFetchRawBlock(res) + if errCode != int(errors.SUCCESS) { + errStr := TaosErrorStr(res) + err := errors.NewError(code, errStr) + t.Error(err) + TaosFreeResult(res) + return + } + if blockSize == 0 { + break + } + d := parser.ReadBlock(block, blockSize, rh.ColTypes, precision) + data = append(data, d...) + } + TaosFreeResult(res) + + assert.Equal(t, 2, len(data)) + row1 := data[0] + assert.Equal(t, now.UnixNano()/1e6, row1[0].(time.Time).UnixNano()/1e6) + assert.Equal(t, true, row1[1].(bool)) + assert.Equal(t, int8(1), row1[2].(int8)) + assert.Equal(t, int16(1), row1[3].(int16)) + assert.Equal(t, int32(1), row1[4].(int32)) + assert.Equal(t, int64(1), row1[5].(int64)) + assert.Equal(t, uint8(1), row1[6].(uint8)) + assert.Equal(t, uint16(1), row1[7].(uint16)) + assert.Equal(t, uint32(1), row1[8].(uint32)) + assert.Equal(t, uint64(1), row1[9].(uint64)) + assert.Equal(t, float32(1), row1[10].(float32)) + assert.Equal(t, float64(1), row1[11].(float64)) + assert.Equal(t, "test_binary", row1[12].(string)) + assert.Equal(t, "test_nchar", row1[13].(string)) + row2 := data[1] + assert.Equal(t, after1s.UnixNano()/1e6, row2[0].(time.Time).UnixNano()/1e6) + for i := 1; i < 14; i++ { + assert.Nil(t, row2[i]) + } +} + +// @author: xftan +// @date: 2023/11/17 9:37 +// @description: test write raw block with fields and reqid +func TestTaosWriteRawBlockWithFieldsWithReqID(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + + defer TaosClose(conn) + res := TaosQuery(conn, "drop database if exists test_write_block_raw_fields_with_reqid") + code := TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + defer func() { + res = TaosQuery(conn, "drop database if exists test_write_block_raw_fields_with_reqid") + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + }() + res = TaosQuery(conn, "create database test_write_block_raw_fields_with_reqid") + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + + res = TaosQuery(conn, "create table if not exists test_write_block_raw_fields_with_reqid.all_type (ts timestamp,"+ + "c1 bool,"+ + "c2 tinyint,"+ + "c3 smallint,"+ + "c4 int,"+ + "c5 bigint,"+ + "c6 tinyint unsigned,"+ + "c7 smallint unsigned,"+ + "c8 int unsigned,"+ + "c9 bigint unsigned,"+ + "c10 float,"+ + "c11 double,"+ + "c12 binary(20),"+ + "c13 nchar(20)"+ + ") tags (info json)") + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + now := time.Now() + after1s := now.Add(time.Second) + sql := fmt.Sprintf("insert into test_write_block_raw_fields_with_reqid.t0 using test_write_block_raw_fields_with_reqid.all_type tags('{\"a\":1}') values('%s',1,1,1,1,1,1,1,1,1,1,1,'test_binary','test_nchar')('%s',null,null,null,null,null,null,null,null,null,null,null,null,null)", now.Format(time.RFC3339Nano), after1s.Format(time.RFC3339Nano)) + res = TaosQuery(conn, sql) + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + + sql = "create table test_write_block_raw_fields_with_reqid.t1 using test_write_block_raw_fields_with_reqid.all_type tags('{\"a\":2}')" + res = TaosQuery(conn, sql) + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + + sql = "use test_write_block_raw_fields_with_reqid" + res = TaosQuery(conn, sql) + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + TaosFreeResult(res) + + sql = "select ts,c1 from test_write_block_raw_fields_with_reqid.t0" + res = TaosQuery(conn, sql) + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + for { + blockSize, errCode, block := TaosFetchRawBlock(res) + if errCode != int(errors.SUCCESS) { + errStr := TaosErrorStr(res) + err := errors.NewError(errCode, errStr) + t.Error(err) + TaosFreeResult(res) + return + } + if blockSize == 0 { + break + } + fieldsCount := TaosNumFields(res) + fields := TaosFetchFields(res) + + errCode = TaosWriteRawBlockWithFieldsWithReqID(conn, blockSize, block, "t1", fields, fieldsCount, 1) + if errCode != int(errors.SUCCESS) { + errStr := TaosErrorStr(nil) + err := errors.NewError(errCode, errStr) + t.Error(err) + TaosFreeResult(res) + return + } + } + TaosFreeResult(res) + + sql = "select * from test_write_block_raw_fields_with_reqid.t1" + res = TaosQuery(conn, sql) + code = TaosError(res) + if code != 0 { + errStr := TaosErrorStr(res) + TaosFreeResult(res) + t.Error(errors.NewError(code, errStr)) + return + } + fileCount := TaosNumFields(res) + rh, err := ReadColumn(res, fileCount) + if err != nil { + t.Error(err) + return + } + precision := TaosResultPrecision(res) + var data [][]driver.Value + for { + blockSize, errCode, block := TaosFetchRawBlock(res) + if errCode != int(errors.SUCCESS) { + errStr := TaosErrorStr(res) + err := errors.NewError(code, errStr) + t.Error(err) + TaosFreeResult(res) + return + } + if blockSize == 0 { + break + } + d := parser.ReadBlock(block, blockSize, rh.ColTypes, precision) + data = append(data, d...) + } + TaosFreeResult(res) + + assert.Equal(t, 2, len(data)) + row1 := data[0] + assert.Equal(t, now.UnixNano()/1e6, row1[0].(time.Time).UnixNano()/1e6) + assert.Equal(t, true, row1[1].(bool)) + for i := 2; i < 14; i++ { + assert.Nil(t, row1[i]) + } + row2 := data[1] + assert.Equal(t, after1s.UnixNano()/1e6, row2[0].(time.Time).UnixNano()/1e6) + for i := 1; i < 14; i++ { + assert.Nil(t, row2[i]) + } +} diff --git a/wrapper/handler/handlerpool_test.go b/wrapper/handler/handlerpool_test.go index e53017a..d2eb581 100644 --- a/wrapper/handler/handlerpool_test.go +++ b/wrapper/handler/handlerpool_test.go @@ -59,11 +59,17 @@ func TestHandlerPool_Get(t *testing.T) { pool.Put(h2) } +// @author: xftan +// @date: 2023/10/13 11:27 +// @description: test caller query func TestCaller_QueryCall(t *testing.T) { caller := NewCaller() caller.QueryCall(nil, 0) } +// @author: xftan +// @date: 2023/10/13 11:27 +// @description: test caller fetch func TestCaller_FetchCall(t *testing.T) { caller := NewCaller() caller.FetchCall(nil, 0) diff --git a/wrapper/notify.go b/wrapper/notify.go new file mode 100644 index 0000000..48480d9 --- /dev/null +++ b/wrapper/notify.go @@ -0,0 +1,24 @@ +package wrapper + +import "C" + +/* +#include +#include +#include +#include +extern void NotifyCallback(void *param, void *ext, int type); +int taos_set_notify_cb_wrapper(TAOS *taos, void *param, int type){ + return taos_set_notify_cb(taos,NotifyCallback,param,type); +}; +*/ +import "C" +import ( + "unsafe" + + "github.com/taosdata/driver-go/v3/wrapper/cgo" +) + +func TaosSetNotifyCB(taosConnect unsafe.Pointer, caller cgo.Handle, notifyType int) int32 { + return int32(C.taos_set_notify_cb_wrapper(taosConnect, caller.Pointer(), (C.int)(notifyType))) +} diff --git a/wrapper/notify_test.go b/wrapper/notify_test.go new file mode 100644 index 0000000..fa8cd06 --- /dev/null +++ b/wrapper/notify_test.go @@ -0,0 +1,97 @@ +package wrapper + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/taosdata/driver-go/v3/common" + "github.com/taosdata/driver-go/v3/wrapper/cgo" +) + +// @author: xftan +// @date: 2023/10/13 11:28 +// @description: test notify callback +func TestNotify(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + + defer TaosClose(conn) + defer exec(conn, "drop user t_notify") + exec(conn, "drop user t_notify") + err = exec(conn, "create user t_notify pass 'notify'") + assert.NoError(t, err) + conn2, err := TaosConnect("", "t_notify", "notify", "", 0) + if err != nil { + t.Error(err) + return + } + + defer TaosClose(conn2) + notify := make(chan int32, 1) + handler := cgo.NewHandle(notify) + errCode := TaosSetNotifyCB(conn2, handler, common.TAOS_NOTIFY_PASSVER) + if errCode != 0 { + errStr := TaosErrorStr(nil) + t.Error(errCode, errStr) + } + err = exec(conn, "alter user t_notify pass 'test'") + assert.NoError(t, err) + timeout, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + now := time.Now() + select { + case version := <-notify: + fmt.Println(time.Now().Sub(now)) + t.Log(version) + case <-timeout.Done(): + t.Error("wait for notify callback timeout") + } + { + notify := make(chan int64, 1) + handler := cgo.NewHandle(notify) + errCode := TaosSetNotifyCB(conn2, handler, common.TAOS_NOTIFY_WHITELIST_VER) + if errCode != 0 { + errStr := TaosErrorStr(nil) + t.Error(errCode, errStr) + } + err = exec(conn, "ALTER USER t_notify ADD HOST '192.168.1.98/0','192.168.1.98/32'") + assert.NoError(t, err) + timeout, cancel = context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + now := time.Now() + select { + case version := <-notify: + fmt.Println(time.Now().Sub(now)) + t.Log(version) + case <-timeout.Done(): + t.Error("wait for notify callback timeout") + } + } + { + notify := make(chan struct{}, 1) + handler := cgo.NewHandle(notify) + errCode := TaosSetNotifyCB(conn2, handler, common.TAOS_NOTIFY_USER_DROPPED) + if errCode != 0 { + errStr := TaosErrorStr(nil) + t.Error(errCode, errStr) + } + err = exec(conn, "drop USER t_notify") + assert.NoError(t, err) + timeout, cancel = context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + now := time.Now() + select { + case _ = <-notify: + fmt.Println(time.Now().Sub(now)) + t.Log("user dropped") + case <-timeout.Done(): + t.Error("wait for notify callback timeout") + } + } +} diff --git a/wrapper/notifycb.go b/wrapper/notifycb.go new file mode 100644 index 0000000..d933345 --- /dev/null +++ b/wrapper/notifycb.go @@ -0,0 +1,36 @@ +package wrapper + +/* +#include +#include +#include +#include +*/ +import "C" +import ( + "unsafe" + + "github.com/taosdata/driver-go/v3/common" + "github.com/taosdata/driver-go/v3/wrapper/cgo" +) + +//export NotifyCallback +func NotifyCallback(p unsafe.Pointer, ext unsafe.Pointer, notifyType C.int) { + defer func() { + // channel may be closed + recover() + }() + switch int(notifyType) { + case common.TAOS_NOTIFY_PASSVER: + version := int32(*(*C.int32_t)(ext)) + c := (*(*cgo.Handle)(p)).Value().(chan int32) + c <- version + case common.TAOS_NOTIFY_WHITELIST_VER: + version := int64(*(*C.int64_t)(ext)) + c := (*(*cgo.Handle)(p)).Value().(chan int64) + c <- version + case common.TAOS_NOTIFY_USER_DROPPED: + c := (*(*cgo.Handle)(p)).Value().(chan struct{}) + c <- struct{}{} + } +} diff --git a/wrapper/row.go b/wrapper/row.go index 6ee6508..38422f8 100644 --- a/wrapper/row.go +++ b/wrapper/row.go @@ -9,6 +9,7 @@ import ( "unsafe" "github.com/taosdata/driver-go/v3/common" + "github.com/taosdata/driver-go/v3/common/pointer" ) const ( @@ -18,7 +19,8 @@ const ( type FormatTimeFunc func(ts int64, precision int) driver.Value func FetchRow(row unsafe.Pointer, offset int, colType uint8, length int, arg ...interface{}) driver.Value { - p := unsafe.Pointer(*(*uintptr)(unsafe.Pointer(uintptr(row) + uintptr(offset)*PointerSize))) + base := *(**C.void)(pointer.AddUintptr(row, uintptr(offset)*PointerSize)) + p := unsafe.Pointer(base) if p == nil { return nil } @@ -52,7 +54,7 @@ func FetchRow(row unsafe.Pointer, offset int, colType uint8, length int, arg ... case C.TSDB_DATA_TYPE_BINARY, C.TSDB_DATA_TYPE_NCHAR: data := make([]byte, length) for i := 0; i < length; i++ { - data[i] = *((*byte)(unsafe.Pointer(uintptr(p) + uintptr(i)))) + data[i] = *((*byte)(pointer.AddUintptr(p, uintptr(i)))) } return string(data) case C.TSDB_DATA_TYPE_TIMESTAMP: @@ -63,10 +65,10 @@ func FetchRow(row unsafe.Pointer, offset int, colType uint8, length int, arg ... } else { panic("convertTime error") } - case C.TSDB_DATA_TYPE_JSON: + case C.TSDB_DATA_TYPE_JSON, C.TSDB_DATA_TYPE_VARBINARY, C.TSDB_DATA_TYPE_GEOMETRY: data := make([]byte, length) for i := 0; i < length; i++ { - data[i] = *((*byte)(unsafe.Pointer(uintptr(p) + uintptr(i)))) + data[i] = *((*byte)(pointer.AddUintptr(p, uintptr(i)))) } return data default: diff --git a/wrapper/row_test.go b/wrapper/row_test.go index b748c6c..92f9425 100644 --- a/wrapper/row_test.go +++ b/wrapper/row_test.go @@ -22,7 +22,7 @@ func TestFetchRowJSON(t *testing.T) { defer TaosClose(conn) defer func() { - res := TaosQuery(conn, "drop database if exists test_json") + res := TaosQuery(conn, "drop database if exists test_json_wrapper") code := TaosError(res) if code != 0 { errStr := TaosErrorStr(res) @@ -32,7 +32,7 @@ func TestFetchRowJSON(t *testing.T) { } TaosFreeResult(res) }() - res := TaosQuery(conn, "create database if not exists test_json") + res := TaosQuery(conn, "create database if not exists test_json_wrapper") code := TaosError(res) if code != 0 { errStr := TaosErrorStr(res) @@ -42,7 +42,7 @@ func TestFetchRowJSON(t *testing.T) { } TaosFreeResult(res) defer func() { - res := TaosQuery(conn, "drop database if exists test_json") + res := TaosQuery(conn, "drop database if exists test_json_wrapper") code := TaosError(res) if code != 0 { errStr := TaosErrorStr(res) @@ -55,7 +55,7 @@ func TestFetchRowJSON(t *testing.T) { } TaosFreeResult(res) }() - res = TaosQuery(conn, "drop table if exists test_json.tjsonr") + res = TaosQuery(conn, "drop table if exists test_json_wrapper.tjsonr") code = TaosError(res) if code != 0 { errStr := TaosErrorStr(res) @@ -67,7 +67,7 @@ func TestFetchRowJSON(t *testing.T) { return } TaosFreeResult(res) - res = TaosQuery(conn, "create stable if not exists test_json.tjsonr(ts timestamp,v int )tags(t json)") + res = TaosQuery(conn, "create stable if not exists test_json_wrapper.tjsonr(ts timestamp,v int )tags(t json)") code = TaosError(res) if code != 0 { errStr := TaosErrorStr(res) @@ -79,7 +79,7 @@ func TestFetchRowJSON(t *testing.T) { return } TaosFreeResult(res) - res = TaosQuery(conn, `insert into test_json.tjr_1 using test_json.tjsonr tags('{"a":1,"b":"b"}')values (now,1)`) + res = TaosQuery(conn, `insert into test_json_wrapper.tjr_1 using test_json_wrapper.tjsonr tags('{"a":1,"b":"b"}')values (now,1)`) code = TaosError(res) if code != 0 { errStr := TaosErrorStr(res) @@ -91,7 +91,7 @@ func TestFetchRowJSON(t *testing.T) { return } TaosFreeResult(res) - res = TaosQuery(conn, `insert into test_json.tjr_2 using test_json.tjsonr tags('{"a":1,"c":"c"}')values (now+1s,1)`) + res = TaosQuery(conn, `insert into test_json_wrapper.tjr_2 using test_json_wrapper.tjsonr tags('{"a":1,"c":"c"}')values (now+1s,1)`) code = TaosError(res) if code != 0 { errStr := TaosErrorStr(res) @@ -103,7 +103,7 @@ func TestFetchRowJSON(t *testing.T) { return } TaosFreeResult(res) - res = TaosQuery(conn, `insert into test_json.tjr_3 using test_json.tjsonr tags('null')values (now+2s,1)`) + res = TaosQuery(conn, `insert into test_json_wrapper.tjr_3 using test_json_wrapper.tjsonr tags('null')values (now+2s,1)`) code = TaosError(res) if code != 0 { errStr := TaosErrorStr(res) @@ -116,7 +116,7 @@ func TestFetchRowJSON(t *testing.T) { } TaosFreeResult(res) - res = TaosQuery(conn, `select * from test_json.tjsonr order by ts`) + res = TaosQuery(conn, `select * from test_json_wrapper.tjsonr order by ts`) code = TaosError(res) if code != 0 { errStr := TaosErrorStr(res) @@ -482,6 +482,9 @@ func TestFetchRowNchar(t *testing.T) { assert.Empty(t, names) } +// @author: xftan +// @date: 2023/10/13 11:28 +// @description: test fetch row all type func TestFetchRowAllType(t *testing.T) { conn, err := TaosConnect("", "root", "taosdata", "", 0) if err != nil { @@ -536,7 +539,9 @@ func TestFetchRowAllType(t *testing.T) { "c10 float,"+ "c11 double,"+ "c12 binary(20),"+ - "c13 nchar(20)"+ + "c13 nchar(20),"+ + "c14 varbinary(20),"+ + "c15 geometry(100)"+ ")"+ "tags(t json)", db)) code = TaosError(res) @@ -560,7 +565,7 @@ func TestFetchRowAllType(t *testing.T) { } TaosFreeResult(res) now := time.Now() - res = TaosQuery(conn, fmt.Sprintf("insert into %s.tb1 values('%s',true,2,3,4,5,6,7,8,9,10,11,'binary','nchar');", db, now.Format(time.RFC3339Nano))) + res = TaosQuery(conn, fmt.Sprintf("insert into %s.tb1 values('%s',true,2,3,4,5,6,7,8,9,10,11,'binary','nchar','varbinary','POINT(100 100)');", db, now.Format(time.RFC3339Nano))) code = TaosError(res) if code != int(errors.SUCCESS) { errStr := TaosErrorStr(res) @@ -617,5 +622,7 @@ func TestFetchRowAllType(t *testing.T) { assert.Equal(t, float64(11), result[11].(float64)) assert.Equal(t, "binary", result[12].(string)) assert.Equal(t, "nchar", result[13].(string)) - assert.Equal(t, []byte(`{"a":1}`), result[14].([]byte)) + assert.Equal(t, []byte("varbinary"), result[14].([]byte)) + assert.Equal(t, []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, result[15].([]byte)) + assert.Equal(t, []byte(`{"a":1}`), result[16].([]byte)) } diff --git a/wrapper/schemaless_test.go b/wrapper/schemaless_test.go index 6edbbdd..db07466 100644 --- a/wrapper/schemaless_test.go +++ b/wrapper/schemaless_test.go @@ -204,6 +204,9 @@ func TestSchemalessInfluxDB(t *testing.T) { } } +// @author: xftan +// @date: 2023/10/13 11:28 +// @description: test schemaless insert with opentsdb telnet line protocol func TestSchemalessRawTelnet(t *testing.T) { conn := prepareEnv() defer wrapper.TaosClose(conn) @@ -243,6 +246,9 @@ func TestSchemalessRawTelnet(t *testing.T) { } } +// @author: xftan +// @date: 2023/10/13 11:29 +// @description: test schemaless insert with opentsdb telnet line protocol func TestSchemalessRawInfluxDB(t *testing.T) { conn := prepareEnv() defer wrapper.TaosClose(conn) @@ -313,6 +319,9 @@ func TestSchemalessRawInfluxDB(t *testing.T) { } } +// @author: xftan +// @date: 2023/10/13 11:29 +// @description: test schemaless insert raw with reqid func TestTaosSchemalessInsertRawWithReqID(t *testing.T) { conn := prepareEnv() defer wrapper.TaosClose(conn) @@ -378,6 +387,9 @@ func TestTaosSchemalessInsertRawWithReqID(t *testing.T) { } } +// @author: xftan +// @date: 2023/10/13 11:29 +// @description: test schemaless insert with reqid func TestTaosSchemalessInsertWithReqID(t *testing.T) { conn := prepareEnv() defer wrapper.TaosClose(conn) @@ -436,6 +448,9 @@ func TestTaosSchemalessInsertWithReqID(t *testing.T) { } } +// @author: xftan +// @date: 2023/10/13 11:29 +// @description: test schemaless insert with ttl func TestTaosSchemalessInsertTTL(t *testing.T) { conn := prepareEnv() defer wrapper.TaosClose(conn) @@ -488,6 +503,9 @@ func TestTaosSchemalessInsertTTL(t *testing.T) { } } +// @author: xftan +// @date: 2023/10/13 11:30 +// @description: test schemaless insert with ttl and reqid func TestTaosSchemalessInsertTTLWithReqID(t *testing.T) { conn := prepareEnv() defer wrapper.TaosClose(conn) @@ -545,6 +563,9 @@ func TestTaosSchemalessInsertTTLWithReqID(t *testing.T) { } } +// @author: xftan +// @date: 2023/10/13 11:30 +// @description: test schemaless insert raw with ttl func TestTaosSchemalessInsertRawTTL(t *testing.T) { conn := prepareEnv() defer wrapper.TaosClose(conn) @@ -596,6 +617,9 @@ func TestTaosSchemalessInsertRawTTL(t *testing.T) { } } +// @author: xftan +// @date: 2023/10/13 11:30 +// @description: test schemaless insert raw with ttl and reqid func TestTaosSchemalessInsertRawTTLWithReqID(t *testing.T) { conn := prepareEnv() defer wrapper.TaosClose(conn) diff --git a/wrapper/stmt.go b/wrapper/stmt.go index cbf5e30..e2bb2c0 100644 --- a/wrapper/stmt.go +++ b/wrapper/stmt.go @@ -11,10 +11,11 @@ import ( "bytes" "database/sql/driver" "errors" - "fmt" "unsafe" "github.com/taosdata/driver-go/v3/common" + "github.com/taosdata/driver-go/v3/common/stmt" + taosError "github.com/taosdata/driver-go/v3/errors" taosTypes "github.com/taosdata/driver-go/v3/types" ) @@ -234,6 +235,28 @@ func generateTaosBindList(params []driver.Value) ([]C.TAOS_MULTI_BIND, []unsafe. *(bind.length) = C.int32_t(clen) needFreePointer = append(needFreePointer, p) bind.buffer_length = C.uintptr_t(clen) + case taosTypes.TaosVarBinary: + bind.buffer_type = C.TSDB_DATA_TYPE_VARBINARY + cbuf := C.CString(string(value)) + needFreePointer = append(needFreePointer, unsafe.Pointer(cbuf)) + bind.buffer = unsafe.Pointer(cbuf) + clen := int32(len(value)) + p := C.malloc(C.size_t(unsafe.Sizeof(clen))) + bind.length = (*C.int32_t)(p) + *(bind.length) = C.int32_t(clen) + needFreePointer = append(needFreePointer, p) + bind.buffer_length = C.uintptr_t(clen) + case taosTypes.TaosGeometry: + bind.buffer_type = C.TSDB_DATA_TYPE_GEOMETRY + cbuf := C.CString(string(value)) + needFreePointer = append(needFreePointer, unsafe.Pointer(cbuf)) + bind.buffer = unsafe.Pointer(cbuf) + clen := int32(len(value)) + p := C.malloc(C.size_t(unsafe.Sizeof(clen))) + bind.length = (*C.int32_t)(p) + *(bind.length) = C.int32_t(clen) + needFreePointer = append(needFreePointer, p) + bind.buffer_length = C.uintptr_t(clen) case taosTypes.TaosNchar: bind.buffer_type = C.TSDB_DATA_TYPE_NCHAR p := unsafe.Pointer(C.CString(string(value))) @@ -338,6 +361,9 @@ func TaosStmtBindParamBatch(stmt unsafe.Pointer, multiBind [][]driver.Value, bin } else { *(*C.int8_t)(current) = C.int8_t(0) } + + l := unsafe.Pointer(uintptr(lengthList) + uintptr(4*i)) + *(*C.int32_t)(l) = C.int32_t(1) } } case taosTypes.TaosTinyintType: @@ -354,6 +380,9 @@ func TaosStmtBindParamBatch(stmt unsafe.Pointer, multiBind [][]driver.Value, bin value := rowData.(taosTypes.TaosTinyint) current := unsafe.Pointer(uintptr(p) + uintptr(i)) *(*C.int8_t)(current) = C.int8_t(value) + + l := unsafe.Pointer(uintptr(lengthList) + uintptr(4*i)) + *(*C.int32_t)(l) = C.int32_t(1) } } case taosTypes.TaosSmallintType: @@ -370,6 +399,9 @@ func TaosStmtBindParamBatch(stmt unsafe.Pointer, multiBind [][]driver.Value, bin value := rowData.(taosTypes.TaosSmallint) current := unsafe.Pointer(uintptr(p) + uintptr(2*i)) *(*C.int16_t)(current) = C.int16_t(value) + + l := unsafe.Pointer(uintptr(lengthList) + uintptr(4*i)) + *(*C.int32_t)(l) = C.int32_t(2) } } case taosTypes.TaosIntType: @@ -386,6 +418,9 @@ func TaosStmtBindParamBatch(stmt unsafe.Pointer, multiBind [][]driver.Value, bin value := rowData.(taosTypes.TaosInt) current := unsafe.Pointer(uintptr(p) + uintptr(4*i)) *(*C.int32_t)(current) = C.int32_t(value) + + l := unsafe.Pointer(uintptr(lengthList) + uintptr(4*i)) + *(*C.int32_t)(l) = C.int32_t(4) } } case taosTypes.TaosBigintType: @@ -402,6 +437,9 @@ func TaosStmtBindParamBatch(stmt unsafe.Pointer, multiBind [][]driver.Value, bin value := rowData.(taosTypes.TaosBigint) current := unsafe.Pointer(uintptr(p) + uintptr(8*i)) *(*C.int64_t)(current) = C.int64_t(value) + + l := unsafe.Pointer(uintptr(lengthList) + uintptr(4*i)) + *(*C.int32_t)(l) = C.int32_t(8) } } case taosTypes.TaosUTinyintType: @@ -418,6 +456,9 @@ func TaosStmtBindParamBatch(stmt unsafe.Pointer, multiBind [][]driver.Value, bin value := rowData.(taosTypes.TaosUTinyint) current := unsafe.Pointer(uintptr(p) + uintptr(i)) *(*C.uint8_t)(current) = C.uint8_t(value) + + l := unsafe.Pointer(uintptr(lengthList) + uintptr(4*i)) + *(*C.int32_t)(l) = C.int32_t(1) } } case taosTypes.TaosUSmallintType: @@ -434,6 +475,9 @@ func TaosStmtBindParamBatch(stmt unsafe.Pointer, multiBind [][]driver.Value, bin value := rowData.(taosTypes.TaosUSmallint) current := unsafe.Pointer(uintptr(p) + uintptr(2*i)) *(*C.uint16_t)(current) = C.uint16_t(value) + + l := unsafe.Pointer(uintptr(lengthList) + uintptr(4*i)) + *(*C.int32_t)(l) = C.int32_t(2) } } case taosTypes.TaosUIntType: @@ -450,6 +494,9 @@ func TaosStmtBindParamBatch(stmt unsafe.Pointer, multiBind [][]driver.Value, bin value := rowData.(taosTypes.TaosUInt) current := unsafe.Pointer(uintptr(p) + uintptr(4*i)) *(*C.uint32_t)(current) = C.uint32_t(value) + + l := unsafe.Pointer(uintptr(lengthList) + uintptr(4*i)) + *(*C.int32_t)(l) = C.int32_t(4) } } case taosTypes.TaosUBigintType: @@ -466,6 +513,9 @@ func TaosStmtBindParamBatch(stmt unsafe.Pointer, multiBind [][]driver.Value, bin value := rowData.(taosTypes.TaosUBigint) current := unsafe.Pointer(uintptr(p) + uintptr(8*i)) *(*C.uint64_t)(current) = C.uint64_t(value) + + l := unsafe.Pointer(uintptr(lengthList) + uintptr(4*i)) + *(*C.int32_t)(l) = C.int32_t(8) } } case taosTypes.TaosFloatType: @@ -482,6 +532,9 @@ func TaosStmtBindParamBatch(stmt unsafe.Pointer, multiBind [][]driver.Value, bin value := rowData.(taosTypes.TaosFloat) current := unsafe.Pointer(uintptr(p) + uintptr(4*i)) *(*C.float)(current) = C.float(value) + + l := unsafe.Pointer(uintptr(lengthList) + uintptr(4*i)) + *(*C.int32_t)(l) = C.int32_t(4) } } case taosTypes.TaosDoubleType: @@ -498,6 +551,9 @@ func TaosStmtBindParamBatch(stmt unsafe.Pointer, multiBind [][]driver.Value, bin value := rowData.(taosTypes.TaosDouble) current := unsafe.Pointer(uintptr(p) + uintptr(8*i)) *(*C.double)(current) = C.double(value) + + l := unsafe.Pointer(uintptr(lengthList) + uintptr(4*i)) + *(*C.int32_t)(l) = C.int32_t(8) } } case taosTypes.TaosBinaryType: @@ -518,6 +574,42 @@ func TaosStmtBindParamBatch(stmt unsafe.Pointer, multiBind [][]driver.Value, bin *(*C.int32_t)(l) = C.int32_t(len(value)) } } + case taosTypes.TaosVarBinaryType: + p = unsafe.Pointer(C.malloc(C.size_t(C.uint(columnType.MaxLen * rowLen)))) + bind.buffer_type = C.TSDB_DATA_TYPE_VARBINARY + bind.buffer_length = C.uintptr_t(columnType.MaxLen) + for i, rowData := range columnData { + currentNull := unsafe.Pointer(uintptr(nullList) + uintptr(i)) + if rowData == nil { + *(*C.char)(currentNull) = C.char(1) + } else { + *(*C.char)(currentNull) = C.char(0) + value := rowData.(taosTypes.TaosVarBinary) + for j := 0; j < len(value); j++ { + *(*C.char)(unsafe.Pointer(uintptr(p) + uintptr(columnType.MaxLen*i+j))) = (C.char)(value[j]) + } + l := unsafe.Pointer(uintptr(lengthList) + uintptr(4*i)) + *(*C.int32_t)(l) = C.int32_t(len(value)) + } + } + case taosTypes.TaosGeometryType: + p = unsafe.Pointer(C.malloc(C.size_t(C.uint(columnType.MaxLen * rowLen)))) + bind.buffer_type = C.TSDB_DATA_TYPE_GEOMETRY + bind.buffer_length = C.uintptr_t(columnType.MaxLen) + for i, rowData := range columnData { + currentNull := unsafe.Pointer(uintptr(nullList) + uintptr(i)) + if rowData == nil { + *(*C.char)(currentNull) = C.char(1) + } else { + *(*C.char)(currentNull) = C.char(0) + value := rowData.(taosTypes.TaosGeometry) + for j := 0; j < len(value); j++ { + *(*C.char)(unsafe.Pointer(uintptr(p) + uintptr(columnType.MaxLen*i+j))) = (C.char)(value[j]) + } + l := unsafe.Pointer(uintptr(lengthList) + uintptr(4*i)) + *(*C.int32_t)(l) = C.int32_t(len(value)) + } + } case taosTypes.TaosNcharType: p = unsafe.Pointer(C.malloc(C.size_t(C.uint(columnType.MaxLen * rowLen)))) bind.buffer_type = C.TSDB_DATA_TYPE_NCHAR @@ -551,6 +643,9 @@ func TaosStmtBindParamBatch(stmt unsafe.Pointer, multiBind [][]driver.Value, bin ts := common.TimeToTimestamp(value.T, value.Precision) current := unsafe.Pointer(uintptr(p) + uintptr(8*i)) *(*C.int64_t)(current) = C.int64_t(ts) + + l := unsafe.Pointer(uintptr(lengthList) + uintptr(4*i)) + *(*C.int32_t)(l) = C.int32_t(8) } } } @@ -586,50 +681,6 @@ func TaosStmtAffectedRowsOnce(stmt unsafe.Pointer) int { //int32_t bytes; //} TAOS_FIELD_E; -type StmtField struct { - Name string - FieldType int8 - Precision uint8 - Scale uint8 - Bytes int32 -} - -func (s *StmtField) GetType() (*taosTypes.ColumnType, error) { - switch s.FieldType { - case common.TSDB_DATA_TYPE_BOOL: - return &taosTypes.ColumnType{Type: taosTypes.TaosBoolType}, nil - case common.TSDB_DATA_TYPE_TINYINT: - return &taosTypes.ColumnType{Type: taosTypes.TaosTinyintType}, nil - case common.TSDB_DATA_TYPE_SMALLINT: - return &taosTypes.ColumnType{Type: taosTypes.TaosSmallintType}, nil - case common.TSDB_DATA_TYPE_INT: - return &taosTypes.ColumnType{Type: taosTypes.TaosIntType}, nil - case common.TSDB_DATA_TYPE_BIGINT: - return &taosTypes.ColumnType{Type: taosTypes.TaosBigintType}, nil - case common.TSDB_DATA_TYPE_UTINYINT: - return &taosTypes.ColumnType{Type: taosTypes.TaosUTinyintType}, nil - case common.TSDB_DATA_TYPE_USMALLINT: - return &taosTypes.ColumnType{Type: taosTypes.TaosUSmallintType}, nil - case common.TSDB_DATA_TYPE_UINT: - return &taosTypes.ColumnType{Type: taosTypes.TaosUIntType}, nil - case common.TSDB_DATA_TYPE_UBIGINT: - return &taosTypes.ColumnType{Type: taosTypes.TaosUBigintType}, nil - case common.TSDB_DATA_TYPE_FLOAT: - return &taosTypes.ColumnType{Type: taosTypes.TaosFloatType}, nil - case common.TSDB_DATA_TYPE_DOUBLE: - return &taosTypes.ColumnType{Type: taosTypes.TaosDoubleType}, nil - case common.TSDB_DATA_TYPE_BINARY: - return &taosTypes.ColumnType{Type: taosTypes.TaosBinaryType}, nil - case common.TSDB_DATA_TYPE_NCHAR: - return &taosTypes.ColumnType{Type: taosTypes.TaosNcharType}, nil - case common.TSDB_DATA_TYPE_TIMESTAMP: - return &taosTypes.ColumnType{Type: taosTypes.TaosTimestampType}, nil - case common.TSDB_DATA_TYPE_JSON: - return &taosTypes.ColumnType{Type: taosTypes.TaosJsonType}, nil - } - return nil, fmt.Errorf("unsupported type: %d, name %s", s.FieldType, s.Name) -} - // TaosStmtGetTagFields DLL_EXPORT int taos_stmt_get_tag_fields(TAOS_STMT *stmt, int* fieldNum, TAOS_FIELD_E** fields); func TaosStmtGetTagFields(stmt unsafe.Pointer) (code, num int, fields unsafe.Pointer) { cNum := unsafe.Pointer(&num) @@ -658,14 +709,14 @@ func TaosStmtGetColFields(stmt unsafe.Pointer) (code, num int, fields unsafe.Poi return code, num, unsafe.Pointer(cField) } -func StmtParseFields(num int, fields unsafe.Pointer) []*StmtField { +func StmtParseFields(num int, fields unsafe.Pointer) []*stmt.StmtField { if num == 0 { return nil } - result := make([]*StmtField, num) + result := make([]*stmt.StmtField, num) buf := bytes.NewBufferString("") for i := 0; i < num; i++ { - r := &StmtField{} + r := &stmt.StmtField{} field := *(*C.TAOS_FIELD_E)(unsafe.Pointer(uintptr(fields) + uintptr(C.sizeof_struct_TAOS_FIELD_E*C.int(i)))) for _, c := range field.name { if c == 0 { @@ -688,3 +739,15 @@ func StmtParseFields(num int, fields unsafe.Pointer) []*StmtField { func TaosStmtReclaimFields(stmt unsafe.Pointer, fields unsafe.Pointer) { C.taos_stmt_reclaim_fields(stmt, (*C.TAOS_FIELD_E)(fields)) } + +// TaosStmtGetParam DLL_EXPORT int taos_stmt_get_param(TAOS_STMT *stmt, int idx, int *type, int *bytes) +func TaosStmtGetParam(stmt unsafe.Pointer, idx int) (dataType int, dataLength int, err error) { + code := C.taos_stmt_get_param(stmt, C.int(idx), (*C.int)(unsafe.Pointer(&dataType)), (*C.int)(unsafe.Pointer(&dataLength))) + if code != 0 { + err = &taosError.TaosError{ + Code: int32(code), + ErrStr: TaosStmtErrStr(stmt), + } + } + return +} diff --git a/wrapper/stmt_test.go b/wrapper/stmt_test.go index a7db700..d924645 100644 --- a/wrapper/stmt_test.go +++ b/wrapper/stmt_test.go @@ -11,6 +11,7 @@ import ( "github.com/taosdata/driver-go/v3/common" "github.com/taosdata/driver-go/v3/common/param" "github.com/taosdata/driver-go/v3/common/parser" + stmtCommon "github.com/taosdata/driver-go/v3/common/stmt" taosError "github.com/taosdata/driver-go/v3/errors" taosTypes "github.com/taosdata/driver-go/v3/types" ) @@ -137,6 +138,26 @@ func TestStmt(t *testing.T) { }}, expectValue: "yes", }, //3 + { + tbType: "ts timestamp, v varbinary(8)", + pos: "?, ?", + params: [][]driver.Value{{taosTypes.TaosTimestamp{T: now, Precision: common.PrecisionMilliSecond}}, {taosTypes.TaosVarBinary("yes")}}, + bindType: []*taosTypes.ColumnType{{Type: taosTypes.TaosTimestampType}, { + Type: taosTypes.TaosVarBinaryType, + MaxLen: 3, + }}, + expectValue: []byte("yes"), + }, //3 + { + tbType: "ts timestamp, v geometry(100)", + pos: "?, ?", + params: [][]driver.Value{{taosTypes.TaosTimestamp{T: now, Precision: common.PrecisionMilliSecond}}, {taosTypes.TaosGeometry{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}}}, + bindType: []*taosTypes.ColumnType{{Type: taosTypes.TaosTimestampType}, { + Type: taosTypes.TaosGeometryType, + MaxLen: 3, + }}, + expectValue: []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, + }, //3 { tbType: "ts timestamp, v nchar(8)", pos: "?, ?", @@ -220,10 +241,7 @@ func TestStmt(t *testing.T) { t.Errorf("expect %d got %d", 1, len(result)) return } - if result[0][0] != tc.expectValue { - t.Errorf("expect %v got %v", tc.expectValue, result[0][0]) - return - } + assert.Equal(t, tc.expectValue, result[0][0]) }) } @@ -335,6 +353,18 @@ func TestStmtExec(t *testing.T) { params: []driver.Value{taosTypes.TaosTimestamp{T: now, Precision: common.PrecisionMilliSecond}, taosTypes.TaosBinary("yes")}, expectValue: "yes", }, //3 + { + tbType: "ts timestamp, v varbinary(8)", + pos: "?, ?", + params: []driver.Value{taosTypes.TaosTimestamp{T: now, Precision: common.PrecisionMilliSecond}, taosTypes.TaosVarBinary("yes")}, + expectValue: []byte("yes"), + }, //3 + { + tbType: "ts timestamp, v geometry(100)", + pos: "?, ?", + params: []driver.Value{taosTypes.TaosTimestamp{T: now, Precision: common.PrecisionMilliSecond}, taosTypes.TaosGeometry{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}}, + expectValue: []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40}, + }, //3 { tbType: "ts timestamp, v nchar(8)", pos: "?, ?", @@ -415,14 +445,14 @@ func TestStmtExec(t *testing.T) { t.Errorf("expect %d got %d", 1, len(result)) return } - if result[0][0] != tc.expectValue { - t.Errorf("expect %v got %v", tc.expectValue, result[0][0]) - return - } + assert.Equal(t, tc.expectValue, result[0][0]) }) } } +// @author: xftan +// @date: 2023/10/13 11:30 +// @description: test stmt query func TestStmtQuery(t *testing.T) { conn, err := TaosConnect("", "root", "taosdata", "", 0) if err != nil { @@ -685,6 +715,9 @@ func StmtQuery(t *testing.T, conn unsafe.Pointer, sql string, params *param.Para return data, nil } +// @author: xftan +// @date: 2023/10/13 11:30 +// @description: test get field func TestGetFields(t *testing.T) { conn, err := TaosConnect("", "root", "taosdata", "", 0) if err != nil { @@ -771,7 +804,7 @@ func TestGetFields(t *testing.T) { defer TaosStmtReclaimFields(stmt, columnsP) columns := StmtParseFields(columnCount, columnsP) tags := StmtParseFields(tagCount, tagsP) - assert.Equal(t, []*StmtField{ + assert.Equal(t, []*stmtCommon.StmtField{ {"ts", 9, 0, 0, 8}, {"c1", 1, 0, 0, 1}, {"c2", 2, 0, 0, 1}, @@ -787,7 +820,7 @@ func TestGetFields(t *testing.T) { {"c12", 8, 0, 0, 22}, {"c13", 10, 0, 0, 82}, }, columns) - assert.Equal(t, []*StmtField{ + assert.Equal(t, []*stmtCommon.StmtField{ {"tts", 9, 0, 0, 8}, {"tc1", 1, 0, 0, 1}, {"tc2", 2, 0, 0, 1}, @@ -805,6 +838,9 @@ func TestGetFields(t *testing.T) { }, tags) } +// @author: xftan +// @date: 2023/10/13 11:30 +// @description: test get fields with common table func TestGetFieldsCommonTable(t *testing.T) { conn, err := TaosConnect("", "root", "taosdata", "", 0) if err != nil { @@ -863,7 +899,7 @@ func TestGetFieldsCommonTable(t *testing.T) { } defer TaosStmtReclaimFields(stmt, columnsP) columns := StmtParseFields(columnCount, columnsP) - assert.Equal(t, []*StmtField{ + assert.Equal(t, []*stmtCommon.StmtField{ {"ts", 9, 0, 0, 8}, {"c1", 1, 0, 0, 1}, {"c2", 2, 0, 0, 1}, @@ -892,6 +928,9 @@ func exec(conn unsafe.Pointer, sql string) error { return nil } +// @author: xftan +// @date: 2023/10/13 11:31 +// @description: test stmt set tags func TestTaosStmtSetTags(t *testing.T) { conn, err := TaosConnect("", "root", "taosdata", "", 0) if err != nil { @@ -1128,3 +1167,46 @@ func TestTaosStmtSetTags(t *testing.T) { assert.Equal(t, int32(102), data[0][2].(int32)) assert.Equal(t, []byte(`{"a":"b"}`), data[0][3].([]byte)) } + +func TestTaosStmtGetParam(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + assert.NoError(t, err) + defer TaosClose(conn) + + err = exec(conn, "drop database if exists test_stmt_get_param") + assert.NoError(t, err) + err = exec(conn, "create database if not exists test_stmt_get_param") + assert.NoError(t, err) + defer exec(conn, "drop database if exists test_stmt_get_param") + + err = exec(conn, + "create table if not exists test_stmt_get_param.stb(ts TIMESTAMP,current float,voltage int,phase float) TAGS (groupid int,location varchar(24))") + assert.NoError(t, err) + + stmt := TaosStmtInit(conn) + assert.NotNilf(t, stmt, "failed to init stmt") + defer TaosStmtClose(stmt) + + code := TaosStmtPrepare(stmt, "insert into test_stmt_get_param.tb_0 using test_stmt_get_param.stb tags(?,?) values (?,?,?,?)") + assert.Equal(t, 0, code, TaosStmtErrStr(stmt)) + + dt, dl, err := TaosStmtGetParam(stmt, 0) // ts + assert.NoError(t, err) + assert.Equal(t, 9, dt) + assert.Equal(t, 8, dl) + + dt, dl, err = TaosStmtGetParam(stmt, 1) // current + assert.NoError(t, err) + assert.Equal(t, 6, dt) + assert.Equal(t, 4, dl) + + dt, dl, err = TaosStmtGetParam(stmt, 2) // voltage + assert.NoError(t, err) + assert.Equal(t, 4, dt) + assert.Equal(t, 4, dl) + + dt, dl, err = TaosStmtGetParam(stmt, 3) // phase + assert.NoError(t, err) + assert.Equal(t, 6, dt) + assert.Equal(t, 4, dl) +} diff --git a/wrapper/taosc.go b/wrapper/taosc.go index 803ed16..2e952a2 100644 --- a/wrapper/taosc.go +++ b/wrapper/taosc.go @@ -33,6 +33,7 @@ import ( "strings" "unsafe" + "github.com/taosdata/driver-go/v3/common/pointer" "github.com/taosdata/driver-go/v3/errors" "github.com/taosdata/driver-go/v3/wrapper/cgo" ) @@ -251,7 +252,38 @@ func TaosGetTablesVgID(conn unsafe.Pointer, db string, tables []string) (vgIDs [ } vgIDs = make([]int, numTables) for i := 0; i < numTables; i++ { - vgIDs[i] = int(*(*C.int)(unsafe.Pointer(uintptr(p) + uintptr(C.sizeof_int*C.int(i))))) + vgIDs[i] = int(*(*C.int)(pointer.AddUintptr(p, uintptr(C.sizeof_int*C.int(i))))) } return } + +//typedef enum { +//TAOS_CONN_MODE_BI = 0, +//} TAOS_CONN_MODE; +// +//DLL_EXPORT int taos_set_conn_mode(TAOS* taos, int mode, int value); + +func TaosSetConnMode(conn unsafe.Pointer, mode int, value int) int { + return int(C.taos_set_conn_mode(conn, C.int(mode), C.int(value))) +} + +// TaosGetCurrentDB DLL_EXPORT int taos_get_current_db(TAOS *taos, char *database, int len, int *required) +func TaosGetCurrentDB(conn unsafe.Pointer) (db string, err error) { + cDb := C.CString(db) + defer C.free(unsafe.Pointer(cDb)) + var required int + + code := C.taos_get_current_db(conn, cDb, C.int(193), (*C.int)(unsafe.Pointer(&required))) + if code != 0 { + err = errors.NewError(int(code), TaosErrorStr(nil)) + } + db = C.GoString(cDb) + + return +} + +// TaosGetServerInfo DLL_EXPORT const char *taos_get_server_info(TAOS *taos) +func TaosGetServerInfo(conn unsafe.Pointer) string { + info := C.taos_get_server_info(conn) + return C.GoString(info) +} diff --git a/wrapper/taosc_test.go b/wrapper/taosc_test.go index 5e41dfa..18ff226 100644 --- a/wrapper/taosc_test.go +++ b/wrapper/taosc_test.go @@ -153,6 +153,9 @@ func TestTaosQueryA(t *testing.T) { } } +// @author: xftan +// @date: 2023/10/13 11:31 +// @description: test taos error func TestError(t *testing.T) { conn, err := TaosConnect("", "root", "taosdata", "", 0) if err != nil { @@ -166,6 +169,9 @@ func TestError(t *testing.T) { assert.NotEmpty(t, errStr) } +// @author: xftan +// @date: 2023/10/13 11:31 +// @description: test affected rows func TestAffectedRows(t *testing.T) { conn, err := TaosConnect("", "root", "taosdata", "", 0) if err != nil { @@ -420,11 +426,17 @@ func TestTaosResultBlock(t *testing.T) { } } +// @author: xftan +// @date: 2023/10/13 11:31 +// @description: test taos_get_client_info func TestTaosGetClientInfo(t *testing.T) { s := TaosGetClientInfo() assert.NotEmpty(t, s) } +// @author: xftan +// @date: 2023/10/13 11:31 +// @description: test taos_load_table_info func TestTaosLoadTableInfo(t *testing.T) { conn, err := TaosConnect("", "root", "taosdata", "", 0) if err != nil { @@ -463,6 +475,9 @@ func TestTaosLoadTableInfo(t *testing.T) { } +// @author: xftan +// @date: 2023/10/13 11:32 +// @description: test taos_get_table_vgId func TestTaosGetTableVgID(t *testing.T) { conn, err := TaosConnect("", "root", "taosdata", "", 0) if err != nil { @@ -510,6 +525,9 @@ func TestTaosGetTableVgID(t *testing.T) { } } +// @author: xftan +// @date: 2023/10/13 11:32 +// @description: test taos_get_tables_vgId func TestTaosGetTablesVgID(t *testing.T) { conn, err := TaosConnect("", "root", "taosdata", "", 0) if err != nil { @@ -552,3 +570,38 @@ func TestTaosGetTablesVgID(t *testing.T) { assert.Equal(t, 2, len(vgs2)) assert.Equal(t, vgs2, vgs1) } + +func TestTaosSetConnMode(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + assert.NoError(t, err) + defer TaosClose(conn) + code := TaosSetConnMode(conn, 0, 1) + if code != 0 { + t.Errorf("TaosSetConnMode() error code= %d, msg: %s", code, TaosErrorStr(nil)) + } +} + +func TestTaosGetCurrentDB(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + assert.NoError(t, err) + defer TaosClose(conn) + dbName := "current_db_test" + _ = exec(conn, fmt.Sprintf("drop database if exists %s", dbName)) + err = exec(conn, fmt.Sprintf("create database %s", dbName)) + assert.NoError(t, err) + defer func() { + _ = exec(conn, fmt.Sprintf("drop database if exists %s", dbName)) + }() + _ = exec(conn, fmt.Sprintf("use %s", dbName)) + db, err := TaosGetCurrentDB(conn) + assert.NoError(t, err) + assert.Equal(t, dbName, db) +} + +func TestTaosGetServerInfo(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + assert.NoError(t, err) + defer TaosClose(conn) + info := TaosGetServerInfo(conn) + assert.NotEmpty(t, info) +} diff --git a/wrapper/tmq.go b/wrapper/tmq.go index 93960ec..949bf10 100644 --- a/wrapper/tmq.go +++ b/wrapper/tmq.go @@ -6,12 +6,16 @@ package wrapper #include #include extern void TMQCommitCB(tmq_t *, int32_t, void *param); +extern void TMQAutoCommitCB(tmq_t *, int32_t, void *param); +extern void TMQCommitOffsetCB(tmq_t *, int32_t, void *param); */ import "C" import ( "sync" "unsafe" + "github.com/taosdata/driver-go/v3/common/pointer" + "github.com/taosdata/driver-go/v3/common/tmq" "github.com/taosdata/driver-go/v3/errors" "github.com/taosdata/driver-go/v3/wrapper/cgo" ) @@ -67,7 +71,7 @@ func TMQConfDestroy(conf unsafe.Pointer) { // TMQConfSetAutoCommitCB DLL_EXPORT void tmq_conf_set_auto_commit_cb(tmq_conf_t *conf, tmq_commit_cb *cb, void *param); func TMQConfSetAutoCommitCB(conf unsafe.Pointer, h cgo.Handle) { - C.tmq_conf_set_auto_commit_cb((*C.struct_tmq_conf_t)(conf), (*C.tmq_commit_cb)(C.TMQCommitCB), h.Pointer()) + C.tmq_conf_set_auto_commit_cb((*C.struct_tmq_conf_t)(conf), (*C.tmq_commit_cb)(C.TMQAutoCommitCB), h.Pointer()) } // TMQCommitAsync DLL_EXPORT void tmq_commit_async(tmq_t *tmq, const TAOS_RES *msg, tmq_commit_cb *cb, void *param); @@ -104,10 +108,10 @@ func TMQListGetSize(list unsafe.Pointer) int32 { // TMQListToCArray char **tmq_list_to_c_array(const tmq_list_t *); func TMQListToCArray(list unsafe.Pointer, size int) []string { - head := uintptr(unsafe.Pointer(C.tmq_list_to_c_array((*C.tmq_list_t)(list)))) + head := unsafe.Pointer(C.tmq_list_to_c_array((*C.tmq_list_t)(list))) result := make([]string, size) for i := 0; i < size; i++ { - result[i] = C.GoString(*(**C.char)(unsafe.Pointer(head + PointerSize*uintptr(i)))) + result[i] = C.GoString(*(**C.char)(pointer.AddUintptr(head, PointerSize*uintptr(i)))) } return result } @@ -249,3 +253,83 @@ func BuildRawMeta(length uint32, metaType uint16, data unsafe.Pointer) unsafe.Po meta.raw_type = (C.uint16_t)(metaType) return unsafe.Pointer(&meta) } + +// TMQGetTopicAssignment DLL_EXPORT int32_t tmq_get_topic_assignment(tmq_t *tmq, const char* pTopicName, tmq_topic_assignment **assignment, int32_t *numOfAssignment) +func TMQGetTopicAssignment(consumer unsafe.Pointer, topic string) (int32, []*tmq.Assignment) { + var assignment *C.tmq_topic_assignment + var numOfAssignment int32 + topicName := C.CString(topic) + defer C.free(unsafe.Pointer(topicName)) + code := int32(C.tmq_get_topic_assignment((*C.tmq_t)(consumer), topicName, (**C.tmq_topic_assignment)(unsafe.Pointer(&assignment)), (*C.int32_t)(&numOfAssignment))) + if code != 0 { + return code, nil + } + if assignment == nil { + return 0, nil + } + defer TMQFreeAssignment(unsafe.Pointer(assignment)) + result := make([]*tmq.Assignment, numOfAssignment) + for i := 0; i < int(numOfAssignment); i++ { + item := *(*C.tmq_topic_assignment)(unsafe.Pointer(uintptr(unsafe.Pointer(assignment)) + uintptr(C.sizeof_struct_tmq_topic_assignment*C.int(i)))) + result[i] = &tmq.Assignment{ + VGroupID: int32(item.vgId), + Offset: int64(item.currentOffset), + Begin: int64(item.begin), + End: int64(item.end), + } + } + return 0, result +} + +// TMQOffsetSeek DLL_EXPORT int32_t tmq_offset_seek(tmq_t* tmq, const char* pTopicName, int32_t vgroupHandle, int64_t offset); +func TMQOffsetSeek(consumer unsafe.Pointer, topic string, vGroupID int32, offset int64) int32 { + topicName := C.CString(topic) + defer C.free(unsafe.Pointer(topicName)) + return int32(C.tmq_offset_seek((*C.tmq_t)(consumer), topicName, (C.int32_t)(vGroupID), (C.int64_t)(offset))) +} + +// TMQGetVgroupOffset DLL_EXPORT int64_t tmq_get_vgroup_offset(TAOS_RES* res, int32_t vgroupId); +func TMQGetVgroupOffset(message unsafe.Pointer) int64 { + return int64(C.tmq_get_vgroup_offset(message)) +} + +// TMQFreeAssignment DLL_EXPORT void tmq_free_assignment(tmq_topic_assignment* pAssignment); +func TMQFreeAssignment(assignment unsafe.Pointer) { + if assignment == nil { + return + } + C.tmq_free_assignment((*C.tmq_topic_assignment)(assignment)) +} + +// TMQPosition DLL_EXPORT int64_t tmq_position(tmq_t *tmq, const char *pTopicName, int32_t vgId); +func TMQPosition(consumer unsafe.Pointer, topic string, vGroupID int32) int64 { + topicName := C.CString(topic) + defer C.free(unsafe.Pointer(topicName)) + return int64(C.tmq_position((*C.tmq_t)(consumer), topicName, (C.int32_t)(vGroupID))) +} + +// TMQCommitted DLL_EXPORT int64_t tmq_committed(tmq_t *tmq, const char *pTopicName, int32_t vgId); +func TMQCommitted(consumer unsafe.Pointer, topic string, vGroupID int32) int64 { + topicName := C.CString(topic) + defer C.free(unsafe.Pointer(topicName)) + return int64(C.tmq_committed((*C.tmq_t)(consumer), topicName, (C.int32_t)(vGroupID))) +} + +// TMQCommitOffsetSync DLL_EXPORT int32_t tmq_commit_offset_sync(tmq_t *tmq, const char *pTopicName, int32_t vgId, int64_t offset); +func TMQCommitOffsetSync(consumer unsafe.Pointer, topic string, vGroupID int32, offset int64) int32 { + topicName := C.CString(topic) + defer C.free(unsafe.Pointer(topicName)) + return int32(C.tmq_commit_offset_sync((*C.tmq_t)(consumer), topicName, (C.int32_t)(vGroupID), (C.int64_t)(offset))) +} + +// TMQCommitOffsetAsync DLL_EXPORT void tmq_commit_offset_async(tmq_t *tmq, const char *pTopicName, int32_t vgId, int64_t offset, tmq_commit_cb *cb, void *param); +func TMQCommitOffsetAsync(consumer unsafe.Pointer, topic string, vGroupID int32, offset int64, h cgo.Handle) { + topicName := C.CString(topic) + defer C.free(unsafe.Pointer(topicName)) + C.tmq_commit_offset_async((*C.tmq_t)(consumer), topicName, (C.int32_t)(vGroupID), (C.int64_t)(offset), (*C.tmq_commit_cb)(C.TMQCommitOffsetCB), h.Pointer()) +} + +// TMQGetConnect TAOS *tmq_get_connect(tmq_t *tmq) +func TMQGetConnect(consumer unsafe.Pointer) unsafe.Pointer { + return unsafe.Pointer(C.tmq_get_connect((*C.tmq_t)(consumer))) +} diff --git a/wrapper/tmq_test.go b/wrapper/tmq_test.go index 3d27e64..b4992cf 100644 --- a/wrapper/tmq_test.go +++ b/wrapper/tmq_test.go @@ -15,6 +15,9 @@ import ( "github.com/taosdata/driver-go/v3/wrapper/cgo" ) +// @author: xftan +// @date: 2023/10/13 11:32 +// @description: test tmq func TestTMQ(t *testing.T) { conn, err := TaosConnect("", "root", "taosdata", "", 0) if err != nil { @@ -145,6 +148,7 @@ func TestTMQ(t *testing.T) { // auto commit default is true then the commitCallback function will be called after 5 seconds TMQConfSet(conf, "enable.auto.commit", "true") TMQConfSet(conf, "group.id", "tg2") + TMQConfSet(conf, "auto.offset.reset", "earliest") c := make(chan *TMQCommitCallbackResult, 1) h := cgo.NewHandle(c) TMQConfSetAutoCommitCB(conf, h) @@ -244,6 +248,9 @@ func TestTMQ(t *testing.T) { } } +// @author: xftan +// @date: 2023/10/13 11:33 +// @description: test TMQList func TestTMQList(t *testing.T) { list := TMQListNew() TMQListAppend(list, "1") @@ -255,6 +262,9 @@ func TestTMQList(t *testing.T) { TMQListDestroy(list) } +// @author: xftan +// @date: 2023/10/13 11:33 +// @description: test tmq subscribe db func TestTMQDB(t *testing.T) { conn, err := TaosConnect("", "root", "taosdata", "", 0) if err != nil { @@ -374,6 +384,7 @@ func TestTMQDB(t *testing.T) { TMQConfSet(conf, "enable.auto.commit", "true") TMQConfSet(conf, "group.id", "tg2") TMQConfSet(conf, "msg.with.table.name", "true") + TMQConfSet(conf, "auto.offset.reset", "earliest") c := make(chan *TMQCommitCallbackResult, 1) h := cgo.NewHandle(c) TMQConfSetAutoCommitCB(conf, h) @@ -470,6 +481,9 @@ func TestTMQDB(t *testing.T) { assert.GreaterOrEqual(t, totalCount, 5) } +// @author: xftan +// @date: 2023/10/13 11:33 +// @description: test tmq subscribe multi tables func TestTMQDBMultiTable(t *testing.T) { conn, err := TaosConnect("", "root", "taosdata", "", 0) if err != nil { @@ -598,6 +612,7 @@ func TestTMQDBMultiTable(t *testing.T) { TMQConfSet(conf, "enable.auto.commit", "true") TMQConfSet(conf, "group.id", "tg2") TMQConfSet(conf, "msg.with.table.name", "true") + TMQConfSet(conf, "auto.offset.reset", "earliest") c := make(chan *TMQCommitCallbackResult, 1) h := cgo.NewHandle(c) TMQConfSetAutoCommitCB(conf, h) @@ -704,6 +719,9 @@ func TestTMQDBMultiTable(t *testing.T) { assert.Emptyf(t, tables, "tables name not empty", tables) } +// @author: xftan +// @date: 2023/10/13 11:33 +// @description: test tmq subscribe db with multi table insert func TestTMQDBMultiInsert(t *testing.T) { conn, err := TaosConnect("", "root", "taosdata", "", 0) if err != nil { @@ -722,7 +740,7 @@ func TestTMQDBMultiInsert(t *testing.T) { } TaosFreeResult(result) }() - result := TaosQuery(conn, "create database if not exists tmq_test_db_multi_insert vgroups 2 WAL_RETENTION_PERIOD 86400") + result := TaosQuery(conn, "create database if not exists tmq_test_db_multi_insert vgroups 2 wal_retention_period 3600") code := TaosError(result) if code != 0 { errStr := TaosErrorStr(result) @@ -810,6 +828,7 @@ func TestTMQDBMultiInsert(t *testing.T) { TMQConfSet(conf, "enable.auto.commit", "true") TMQConfSet(conf, "group.id", "tg2") TMQConfSet(conf, "msg.with.table.name", "true") + TMQConfSet(conf, "auto.offset.reset", "earliest") c := make(chan *TMQCommitCallbackResult, 1) h := cgo.NewHandle(c) TMQConfSetAutoCommitCB(conf, h) @@ -909,6 +928,9 @@ func TestTMQDBMultiInsert(t *testing.T) { t.Log(tables) } +// @author: xftan +// @date: 2023/10/13 11:34 +// @description: tmq test modify meta func TestTMQModify(t *testing.T) { conn, err := TaosConnect("", "root", "taosdata", "", 0) if err != nil { @@ -956,7 +978,6 @@ func TestTMQModify(t *testing.T) { return } TaosFreeResult(result) - result = TaosQuery(conn, "create database if not exists tmq_test_db_modify_target vgroups 2 WAL_RETENTION_PERIOD 86400") code = TaosError(result) if code != 0 { @@ -1014,6 +1035,7 @@ func TestTMQModify(t *testing.T) { TMQConfSet(conf, "enable.auto.commit", "true") TMQConfSet(conf, "group.id", "tg2") TMQConfSet(conf, "msg.with.table.name", "true") + TMQConfSet(conf, "auto.offset.reset", "earliest") c := make(chan *TMQCommitCallbackResult, 1) h := cgo.NewHandle(c) TMQConfSetAutoCommitCB(conf, h) @@ -1148,36 +1170,70 @@ func TestTMQModify(t *testing.T) { } d, err := query(targetConn, "describe stb") assert.NoError(t, err) - assert.Equal(t, [][]driver.Value{ - {"ts", "TIMESTAMP", int32(8), ""}, - {"c1", "BOOL", int32(1), ""}, - {"c2", "TINYINT", int32(1), ""}, - {"c3", "SMALLINT", int32(2), ""}, - {"c4", "INT", int32(4), ""}, - {"c5", "BIGINT", int32(8), ""}, - {"c6", "TINYINT UNSIGNED", int32(1), ""}, - {"c7", "SMALLINT UNSIGNED", int32(2), ""}, - {"c8", "INT UNSIGNED", int32(4), ""}, - {"c9", "BIGINT UNSIGNED", int32(8), ""}, - {"c10", "FLOAT", int32(4), ""}, - {"c11", "DOUBLE", int32(8), ""}, - {"c12", "VARCHAR", int32(20), ""}, - {"c13", "NCHAR", int32(20), ""}, - {"tts", "TIMESTAMP", int32(8), "TAG"}, - {"tc1", "BOOL", int32(1), "TAG"}, - {"tc2", "TINYINT", int32(1), "TAG"}, - {"tc3", "SMALLINT", int32(2), "TAG"}, - {"tc4", "INT", int32(4), "TAG"}, - {"tc5", "BIGINT", int32(8), "TAG"}, - {"tc6", "TINYINT UNSIGNED", int32(1), "TAG"}, - {"tc7", "SMALLINT UNSIGNED", int32(2), "TAG"}, - {"tc8", "INT UNSIGNED", int32(4), "TAG"}, - {"tc9", "BIGINT UNSIGNED", int32(8), "TAG"}, - {"tc10", "FLOAT", int32(4), "TAG"}, - {"tc11", "DOUBLE", int32(8), "TAG"}, - {"tc12", "VARCHAR", int32(20), "TAG"}, - {"tc13", "NCHAR", int32(20), "TAG"}, - }, d) + if len(d[0]) == 4 { + assert.Equal(t, [][]driver.Value{ + {"ts", "TIMESTAMP", int32(8), ""}, + {"c1", "BOOL", int32(1), ""}, + {"c2", "TINYINT", int32(1), ""}, + {"c3", "SMALLINT", int32(2), ""}, + {"c4", "INT", int32(4), ""}, + {"c5", "BIGINT", int32(8), ""}, + {"c6", "TINYINT UNSIGNED", int32(1), ""}, + {"c7", "SMALLINT UNSIGNED", int32(2), ""}, + {"c8", "INT UNSIGNED", int32(4), ""}, + {"c9", "BIGINT UNSIGNED", int32(8), ""}, + {"c10", "FLOAT", int32(4), ""}, + {"c11", "DOUBLE", int32(8), ""}, + {"c12", "VARCHAR", int32(20), ""}, + {"c13", "NCHAR", int32(20), ""}, + {"tts", "TIMESTAMP", int32(8), "TAG"}, + {"tc1", "BOOL", int32(1), "TAG"}, + {"tc2", "TINYINT", int32(1), "TAG"}, + {"tc3", "SMALLINT", int32(2), "TAG"}, + {"tc4", "INT", int32(4), "TAG"}, + {"tc5", "BIGINT", int32(8), "TAG"}, + {"tc6", "TINYINT UNSIGNED", int32(1), "TAG"}, + {"tc7", "SMALLINT UNSIGNED", int32(2), "TAG"}, + {"tc8", "INT UNSIGNED", int32(4), "TAG"}, + {"tc9", "BIGINT UNSIGNED", int32(8), "TAG"}, + {"tc10", "FLOAT", int32(4), "TAG"}, + {"tc11", "DOUBLE", int32(8), "TAG"}, + {"tc12", "VARCHAR", int32(20), "TAG"}, + {"tc13", "NCHAR", int32(20), "TAG"}, + }, d) + } else { + assert.Equal(t, [][]driver.Value{ + {"ts", "TIMESTAMP", int32(8), "", ""}, + {"c1", "BOOL", int32(1), "", ""}, + {"c2", "TINYINT", int32(1), "", ""}, + {"c3", "SMALLINT", int32(2), "", ""}, + {"c4", "INT", int32(4), "", ""}, + {"c5", "BIGINT", int32(8), "", ""}, + {"c6", "TINYINT UNSIGNED", int32(1), "", ""}, + {"c7", "SMALLINT UNSIGNED", int32(2), "", ""}, + {"c8", "INT UNSIGNED", int32(4), "", ""}, + {"c9", "BIGINT UNSIGNED", int32(8), "", ""}, + {"c10", "FLOAT", int32(4), "", ""}, + {"c11", "DOUBLE", int32(8), "", ""}, + {"c12", "VARCHAR", int32(20), "", ""}, + {"c13", "NCHAR", int32(20), "", ""}, + {"tts", "TIMESTAMP", int32(8), "TAG", ""}, + {"tc1", "BOOL", int32(1), "TAG", ""}, + {"tc2", "TINYINT", int32(1), "TAG", ""}, + {"tc3", "SMALLINT", int32(2), "TAG", ""}, + {"tc4", "INT", int32(4), "TAG", ""}, + {"tc5", "BIGINT", int32(8), "TAG", ""}, + {"tc6", "TINYINT UNSIGNED", int32(1), "TAG", ""}, + {"tc7", "SMALLINT UNSIGNED", int32(2), "TAG", ""}, + {"tc8", "INT UNSIGNED", int32(4), "TAG", ""}, + {"tc9", "BIGINT UNSIGNED", int32(8), "TAG", ""}, + {"tc10", "FLOAT", int32(4), "TAG", ""}, + {"tc11", "DOUBLE", int32(8), "TAG", ""}, + {"tc12", "VARCHAR", int32(20), "TAG", ""}, + {"tc13", "NCHAR", int32(20), "TAG", ""}, + }, d) + } + }) TMQUnsubscribe(tmq) @@ -1189,6 +1245,9 @@ func TestTMQModify(t *testing.T) { } } +// @author: xftan +// @date: 2023/10/13 11:34 +// @description: test tmq subscribe with auto create table func TestTMQAutoCreateTable(t *testing.T) { conn, err := TaosConnect("", "root", "taosdata", "", 0) if err != nil { @@ -1273,6 +1332,7 @@ func TestTMQAutoCreateTable(t *testing.T) { TMQConfSet(conf, "enable.auto.commit", "true") TMQConfSet(conf, "group.id", "tg2") TMQConfSet(conf, "msg.with.table.name", "true") + TMQConfSet(conf, "auto.offset.reset", "earliest") c := make(chan *TMQCommitCallbackResult, 1) h := cgo.NewHandle(c) TMQConfSetAutoCommitCB(conf, h) @@ -1379,3 +1439,604 @@ func TestTMQAutoCreateTable(t *testing.T) { } assert.GreaterOrEqual(t, totalCount, 1) } + +// @author: xftan +// @date: 2023/10/13 11:35 +// @description: test tmq get assignment +func TestTMQGetTopicAssignment(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Fatal(err) + return + } + defer TaosClose(conn) + + defer func() { + if err = taosOperation(conn, "drop database if exists test_tmq_get_topic_assignment"); err != nil { + t.Error(err) + } + }() + + if err = taosOperation(conn, "create database if not exists test_tmq_get_topic_assignment vgroups 1 WAL_RETENTION_PERIOD 86400"); err != nil { + t.Fatal(err) + return + } + if err = taosOperation(conn, "use test_tmq_get_topic_assignment"); err != nil { + t.Fatal(err) + return + } + if err = taosOperation(conn, "create table if not exists t (ts timestamp,v int)"); err != nil { + t.Fatal(err) + return + } + + // create topic + if err = taosOperation(conn, "create topic if not exists test_tmq_assignment as select * from t"); err != nil { + t.Fatal(err) + return + } + + defer func() { + if err = taosOperation(conn, "drop topic if exists test_tmq_assignment"); err != nil { + t.Error(err) + } + }() + + conf := TMQConfNew() + defer TMQConfDestroy(conf) + TMQConfSet(conf, "group.id", "tg2") + TMQConfSet(conf, "auto.offset.reset", "earliest") + tmq, err := TMQConsumerNew(conf) + if err != nil { + t.Fatal(err) + } + defer TMQConsumerClose(tmq) + + topicList := TMQListNew() + TMQListAppend(topicList, "test_tmq_assignment") + + errCode := TMQSubscribe(tmq, topicList) + if errCode != 0 { + errStr := TMQErr2Str(errCode) + t.Fatal(errors.NewError(int(errCode), errStr)) + return + } + + code, assignment := TMQGetTopicAssignment(tmq, "test_tmq_assignment") + if code != 0 { + t.Fatal(errors.NewError(int(code), TMQErr2Str(code))) + } + assert.Equal(t, 1, len(assignment)) + assert.Equal(t, int64(0), assignment[0].Begin) + assert.Equal(t, int64(0), assignment[0].Offset) + assert.GreaterOrEqual(t, assignment[0].End, assignment[0].Offset) + end := assignment[0].End + vgID, vgCode := TaosGetTableVgID(conn, "test_tmq_get_topic_assignment", "t") + if vgCode != 0 { + t.Fatal(errors.NewError(int(vgCode), TMQErr2Str(code))) + } + assert.Equal(t, int32(vgID), assignment[0].VGroupID) + + _ = taosOperation(conn, "insert into t values(now,1)") + haveMessage := false + for i := 0; i < 3; i++ { + message := TMQConsumerPoll(tmq, 500) + if message != nil { + haveMessage = true + TMQCommitSync(tmq, message) + TaosFreeResult(message) + break + } + } + assert.True(t, haveMessage, "expect have message") + code, assignment = TMQGetTopicAssignment(tmq, "test_tmq_assignment") + if code != 0 { + t.Fatal(errors.NewError(int(code), TMQErr2Str(code))) + } + assert.Equal(t, 1, len(assignment)) + assert.Equal(t, int64(0), assignment[0].Begin) + assert.GreaterOrEqual(t, assignment[0].End, end) + end = assignment[0].End + assert.Equal(t, int32(vgID), assignment[0].VGroupID) + + //seek + code = TMQOffsetSeek(tmq, "test_tmq_assignment", int32(vgID), 0) + if code != 0 { + t.Fatal(errors.NewError(int(code), TMQErr2Str(code))) + } + code, assignment = TMQGetTopicAssignment(tmq, "test_tmq_assignment") + if code != 0 { + t.Fatal(errors.NewError(int(code), TMQErr2Str(code))) + } + assert.Equal(t, 1, len(assignment)) + assert.Equal(t, int64(0), assignment[0].Begin) + assert.Equal(t, int64(0), assignment[0].Offset) + assert.GreaterOrEqual(t, assignment[0].End, end) + end = assignment[0].End + assert.Equal(t, int32(vgID), assignment[0].VGroupID) + + haveMessage = false + for i := 0; i < 3; i++ { + message := TMQConsumerPoll(tmq, 500) + if message != nil { + haveMessage = true + TMQCommitSync(tmq, message) + TaosFreeResult(message) + break + } + } + assert.True(t, haveMessage, "expect have message") + code, assignment = TMQGetTopicAssignment(tmq, "test_tmq_assignment") + if code != 0 { + t.Fatal(errors.NewError(int(code), TMQErr2Str(code))) + } + assert.Equal(t, 1, len(assignment)) + assert.Equal(t, int64(0), assignment[0].Begin) + assert.GreaterOrEqual(t, assignment[0].End, end) + end = assignment[0].End + assert.Equal(t, int32(vgID), assignment[0].VGroupID) + + // seek twice + code = TMQOffsetSeek(tmq, "test_tmq_assignment", int32(vgID), 1) + if code != 0 { + t.Fatal(errors.NewError(int(code), TMQErr2Str(code))) + } + code, assignment = TMQGetTopicAssignment(tmq, "test_tmq_assignment") + if code != 0 { + t.Fatal(errors.NewError(int(code), TMQErr2Str(code))) + } + assert.Equal(t, 1, len(assignment)) + assert.Equal(t, int64(0), assignment[0].Begin) + assert.GreaterOrEqual(t, assignment[0].End, end) + end = assignment[0].End + assert.Equal(t, int32(vgID), assignment[0].VGroupID) + + haveMessage = false + for i := 0; i < 3; i++ { + message := TMQConsumerPoll(tmq, 500) + if message != nil { + haveMessage = true + offset := TMQGetVgroupOffset(message) + assert.Greater(t, offset, int64(0)) + TMQCommitSync(tmq, message) + TaosFreeResult(message) + break + } + } + assert.True(t, haveMessage, "expect have message") + code, assignment = TMQGetTopicAssignment(tmq, "test_tmq_assignment") + if code != 0 { + t.Fatal(errors.NewError(int(code), TMQErr2Str(code))) + } + assert.Equal(t, 1, len(assignment)) + assert.Equal(t, int64(0), assignment[0].Begin) + assert.GreaterOrEqual(t, assignment[0].End, end) + end = assignment[0].End + assert.Equal(t, int32(vgID), assignment[0].VGroupID) +} + +func TestTMQPosition(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Fatal(err) + return + } + defer TaosClose(conn) + + defer func() { + if err = taosOperation(conn, "drop database if exists test_tmq_position"); err != nil { + t.Error(err) + } + }() + + if err = taosOperation(conn, "create database if not exists test_tmq_position vgroups 1 WAL_RETENTION_PERIOD 86400"); err != nil { + t.Fatal(err) + return + } + if err = taosOperation(conn, "use test_tmq_position"); err != nil { + t.Fatal(err) + return + } + if err = taosOperation(conn, "create table if not exists t (ts timestamp,v int)"); err != nil { + t.Fatal(err) + return + } + + // create topic + if err = taosOperation(conn, "create topic if not exists test_tmq_position_topic as select * from t"); err != nil { + t.Fatal(err) + return + } + + defer func() { + if err = taosOperation(conn, "drop topic if exists test_tmq_position_topic"); err != nil { + t.Error(err) + } + }() + + conf := TMQConfNew() + defer TMQConfDestroy(conf) + TMQConfSet(conf, "group.id", "position") + TMQConfSet(conf, "auto.offset.reset", "earliest") + + tmq, err := TMQConsumerNew(conf) + if err != nil { + t.Fatal(err) + } + defer TMQConsumerClose(tmq) + + topicList := TMQListNew() + TMQListAppend(topicList, "test_tmq_position_topic") + + errCode := TMQSubscribe(tmq, topicList) + if errCode != 0 { + errStr := TMQErr2Str(errCode) + t.Fatal(errors.NewError(int(errCode), errStr)) + return + } + _ = taosOperation(conn, "insert into t values(now,1)") + code, assignment := TMQGetTopicAssignment(tmq, "test_tmq_position_topic") + if code != 0 { + t.Fatal(errors.NewError(int(code), TMQErr2Str(code))) + } + vgID := assignment[0].VGroupID + position := TMQPosition(tmq, "test_tmq_position_topic", vgID) + assert.Equal(t, position, int64(0)) + haveMessage := false + for i := 0; i < 3; i++ { + message := TMQConsumerPoll(tmq, 500) + if message != nil { + haveMessage = true + position := TMQPosition(tmq, "test_tmq_position_topic", vgID) + assert.Greater(t, position, int64(0)) + committed := TMQCommitted(tmq, "test_tmq_position_topic", vgID) + assert.Less(t, committed, int64(0)) + TMQCommitSync(tmq, message) + position = TMQPosition(tmq, "test_tmq_position_topic", vgID) + committed = TMQCommitted(tmq, "test_tmq_position_topic", vgID) + assert.Equal(t, position, committed) + TaosFreeResult(message) + break + } + } + assert.True(t, haveMessage, "expect have message") + errCode = TMQUnsubscribe(tmq) + if errCode != 0 { + errStr := TMQErr2Str(errCode) + t.Error(errors.NewError(int(errCode), errStr)) + return + } +} + +func TestTMQCommitOffset(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Fatal(err) + return + } + defer TaosClose(conn) + + defer func() { + if err = taosOperation(conn, "drop database if exists test_tmq_commit_offset"); err != nil { + t.Error(err) + } + }() + + if err = taosOperation(conn, "create database if not exists test_tmq_commit_offset vgroups 1 WAL_RETENTION_PERIOD 86400"); err != nil { + t.Fatal(err) + return + } + if err = taosOperation(conn, "use test_tmq_commit_offset"); err != nil { + t.Fatal(err) + return + } + if err = taosOperation(conn, "create table if not exists t (ts timestamp,v int)"); err != nil { + t.Fatal(err) + return + } + + // create topic + if err = taosOperation(conn, "create topic if not exists test_tmq_commit_offset_topic as select * from t"); err != nil { + t.Fatal(err) + return + } + + defer func() { + if err = taosOperation(conn, "drop topic if exists test_tmq_commit_offset_topic"); err != nil { + t.Error(err) + } + }() + + conf := TMQConfNew() + defer TMQConfDestroy(conf) + TMQConfSet(conf, "group.id", "commit") + TMQConfSet(conf, "auto.offset.reset", "earliest") + + tmq, err := TMQConsumerNew(conf) + if err != nil { + t.Fatal(err) + } + defer TMQConsumerClose(tmq) + + topicList := TMQListNew() + TMQListAppend(topicList, "test_tmq_commit_offset_topic") + + errCode := TMQSubscribe(tmq, topicList) + if errCode != 0 { + errStr := TMQErr2Str(errCode) + t.Fatal(errors.NewError(int(errCode), errStr)) + return + } + _ = taosOperation(conn, "insert into t values(now,1)") + code, assignment := TMQGetTopicAssignment(tmq, "test_tmq_commit_offset_topic") + if code != 0 { + t.Fatal(errors.NewError(int(code), TMQErr2Str(code))) + } + vgID := assignment[0].VGroupID + haveMessage := false + for i := 0; i < 3; i++ { + message := TMQConsumerPoll(tmq, 500) + if message != nil { + haveMessage = true + position := TMQPosition(tmq, "test_tmq_commit_offset_topic", vgID) + assert.Greater(t, position, int64(0)) + committed := TMQCommitted(tmq, "test_tmq_commit_offset_topic", vgID) + assert.Less(t, committed, int64(0)) + offset := TMQGetVgroupOffset(message) + code = TMQCommitOffsetSync(tmq, "test_tmq_commit_offset_topic", vgID, offset) + if code != 0 { + t.Fatal(errors.NewError(int(code), TMQErr2Str(code))) + } + committed = TMQCommitted(tmq, "test_tmq_commit_offset_topic", vgID) + assert.Equal(t, int64(offset), committed) + TaosFreeResult(message) + break + } + } + assert.True(t, haveMessage, "expect have message") + errCode = TMQUnsubscribe(tmq) + if errCode != 0 { + errStr := TMQErr2Str(errCode) + t.Error(errors.NewError(int(errCode), errStr)) + return + } +} + +func TestTMQCommitOffsetAsync(t *testing.T) { + topic := "test_tmq_commit_offset_a_topic" + tableName := "test_tmq_commit_offset_a" + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Fatal(err) + return + } + defer TaosClose(conn) + + defer func() { + if err = taosOperation(conn, "drop database if exists "+tableName); err != nil { + t.Error(err) + } + }() + + if err = taosOperation(conn, "create database if not exists "+tableName+" vgroups 1 WAL_RETENTION_PERIOD 86400"); err != nil { + t.Fatal(err) + return + } + if err = taosOperation(conn, "use "+tableName); err != nil { + t.Fatal(err) + return + } + if err = taosOperation(conn, "create table if not exists t (ts timestamp,v int)"); err != nil { + t.Fatal(err) + return + } + + // create topic + + if err = taosOperation(conn, "create topic if not exists "+topic+" as select * from t"); err != nil { + t.Fatal(err) + return + } + + defer func() { + if err = taosOperation(conn, "drop topic if exists "+topic); err != nil { + t.Error(err) + } + }() + + conf := TMQConfNew() + defer TMQConfDestroy(conf) + TMQConfSet(conf, "group.id", "commit_a") + TMQConfSet(conf, "auto.offset.reset", "earliest") + + tmq, err := TMQConsumerNew(conf) + if err != nil { + t.Fatal(err) + } + defer TMQConsumerClose(tmq) + + topicList := TMQListNew() + TMQListAppend(topicList, topic) + + errCode := TMQSubscribe(tmq, topicList) + if errCode != 0 { + errStr := TMQErr2Str(errCode) + t.Fatal(errors.NewError(int(errCode), errStr)) + return + } + _ = taosOperation(conn, "insert into t values(now,1)") + code, assignment := TMQGetTopicAssignment(tmq, topic) + if code != 0 { + t.Fatal(errors.NewError(int(code), TMQErr2Str(code))) + } + vgID := assignment[0].VGroupID + haveMessage := false + for i := 0; i < 3; i++ { + message := TMQConsumerPoll(tmq, 500) + if message != nil { + haveMessage = true + position := TMQPosition(tmq, topic, vgID) + assert.Greater(t, position, int64(0)) + committed := TMQCommitted(tmq, topic, vgID) + assert.Less(t, committed, int64(0)) + offset := TMQGetVgroupOffset(message) + c := make(chan *TMQCommitCallbackResult, 1) + handler := cgo.NewHandle(c) + TMQCommitOffsetAsync(tmq, topic, vgID, offset, handler) + timer := time.NewTimer(time.Second * 5) + select { + case r := <-c: + code = r.ErrCode + if code != 0 { + t.Fatal(errors.NewError(int(code), TMQErr2Str(code))) + } + timer.Stop() + case <-timer.C: + t.Fatal("commit async timeout") + timer.Stop() + } + committed = TMQCommitted(tmq, topic, vgID) + assert.Equal(t, int64(offset), committed) + TaosFreeResult(message) + break + } + } + assert.True(t, haveMessage, "expect have message") + errCode = TMQUnsubscribe(tmq) + if errCode != 0 { + errStr := TMQErr2Str(errCode) + t.Error(errors.NewError(int(errCode), errStr)) + return + } +} + +func TestTMQCommitAsyncCallback(t *testing.T) { + topic := "test_tmq_commit_a_cb_topic" + tableName := "test_tmq_commit_a_cb" + conn, err := TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Fatal(err) + return + } + defer TaosClose(conn) + + defer func() { + if err = taosOperation(conn, "drop database if exists "+tableName); err != nil { + t.Error(err) + } + }() + + if err = taosOperation(conn, "create database if not exists "+tableName+" vgroups 1 WAL_RETENTION_PERIOD 86400"); err != nil { + t.Fatal(err) + return + } + if err = taosOperation(conn, "use "+tableName); err != nil { + t.Fatal(err) + return + } + if err = taosOperation(conn, "create table if not exists t (ts timestamp,v int)"); err != nil { + t.Fatal(err) + return + } + + // create topic + + if err = taosOperation(conn, "create topic if not exists "+topic+" as select * from t"); err != nil { + t.Fatal(err) + return + } + + defer func() { + if err = taosOperation(conn, "drop topic if exists "+topic); err != nil { + t.Error(err) + } + }() + + conf := TMQConfNew() + defer TMQConfDestroy(conf) + TMQConfSet(conf, "group.id", "commit_a") + TMQConfSet(conf, "enable.auto.commit", "false") + TMQConfSet(conf, "auto.offset.reset", "earliest") + TMQConfSet(conf, "auto.commit.interval.ms", "100") + c := make(chan *TMQCommitCallbackResult, 1) + h := cgo.NewHandle(c) + TMQConfSetAutoCommitCB(conf, h) + go func() { + for r := range c { + t.Log("auto commit", r) + PutTMQCommitCallbackResult(r) + } + }() + + tmq, err := TMQConsumerNew(conf) + if err != nil { + t.Fatal(err) + } + defer TMQConsumerClose(tmq) + + topicList := TMQListNew() + TMQListAppend(topicList, topic) + + errCode := TMQSubscribe(tmq, topicList) + if errCode != 0 { + errStr := TMQErr2Str(errCode) + t.Fatal(errors.NewError(int(errCode), errStr)) + return + } + _ = taosOperation(conn, "insert into t values(now,1)") + code, assignment := TMQGetTopicAssignment(tmq, topic) + if code != 0 { + t.Fatal(errors.NewError(int(code), TMQErr2Str(code))) + } + vgID := assignment[0].VGroupID + haveMessage := false + for i := 0; i < 3; i++ { + message := TMQConsumerPoll(tmq, 500) + if message != nil { + haveMessage = true + position := TMQPosition(tmq, topic, vgID) + assert.Greater(t, position, int64(0)) + committed := TMQCommitted(tmq, topic, vgID) + assert.Less(t, committed, int64(0)) + offset := TMQGetVgroupOffset(message) + TMQCommitOffsetSync(tmq, topic, vgID, offset) + committed = TMQCommitted(tmq, topic, vgID) + assert.Equal(t, offset, committed) + TaosFreeResult(message) + } + } + assert.True(t, haveMessage, "expect have message") + committed := TMQCommitted(tmq, topic, vgID) + t.Log(committed) + code, assignment = TMQGetTopicAssignment(tmq, topic) + if code != 0 { + t.Fatal(errors.NewError(int(code), TMQErr2Str(code))) + } + t.Log(assignment[0].Offset) + TMQCommitOffsetSync(tmq, topic, vgID, 1) + committed = TMQCommitted(tmq, topic, vgID) + assert.Equal(t, int64(1), committed) + code, assignment = TMQGetTopicAssignment(tmq, topic) + if code != 0 { + t.Fatal(errors.NewError(int(code), TMQErr2Str(code))) + } + t.Log(assignment[0].Offset) + position := TMQPosition(tmq, topic, vgID) + t.Log(position) + errCode = TMQUnsubscribe(tmq) + if errCode != 0 { + errStr := TMQErr2Str(errCode) + t.Error(errors.NewError(int(errCode), errStr)) + return + } +} + +func taosOperation(conn unsafe.Pointer, sql string) (err error) { + res := TaosQuery(conn, sql) + defer TaosFreeResult(res) + code := TaosError(res) + if code != 0 { + err = errors.NewError(code, TaosErrorStr(res)) + } + return +} diff --git a/wrapper/tmqcb.go b/wrapper/tmqcb.go index 82ae1ee..3893944 100644 --- a/wrapper/tmqcb.go +++ b/wrapper/tmqcb.go @@ -25,3 +25,25 @@ func TMQCommitCB(consumer unsafe.Pointer, resp C.int32_t, param unsafe.Pointer) }() c <- r } + +//export TMQAutoCommitCB +func TMQAutoCommitCB(consumer unsafe.Pointer, resp C.int32_t, param unsafe.Pointer) { + c := (*(*cgo.Handle)(param)).Value().(chan *TMQCommitCallbackResult) + r := GetTMQCommitCallbackResult(int32(resp), consumer) + defer func() { + // Avoid panic due to channel closed + recover() + }() + c <- r +} + +//export TMQCommitOffsetCB +func TMQCommitOffsetCB(consumer unsafe.Pointer, resp C.int32_t, param unsafe.Pointer) { + c := (*(*cgo.Handle)(param)).Value().(chan *TMQCommitCallbackResult) + r := GetTMQCommitCallbackResult(int32(resp), consumer) + defer func() { + // Avoid panic due to channel closed + recover() + }() + c <- r +} diff --git a/wrapper/whitelist.go b/wrapper/whitelist.go new file mode 100644 index 0000000..32da258 --- /dev/null +++ b/wrapper/whitelist.go @@ -0,0 +1,29 @@ +package wrapper + +/* +#cgo CFLAGS: -IC:/TDengine/include -I/usr/include +#cgo linux LDFLAGS: -L/usr/lib -ltaos +#cgo windows LDFLAGS: -LC:/TDengine/driver -ltaos +#cgo darwin LDFLAGS: -L/usr/local/lib -ltaos +#include +#include +#include +#include +extern void WhitelistCallback(void *param, int code, TAOS *taos, int numOfWhiteLists, uint64_t* pWhiteLists); +void taos_fetch_whitelist_a_wrapper(TAOS *taos, void *param){ + return taos_fetch_whitelist_a(taos, WhitelistCallback, param); +}; +*/ +import "C" +import ( + "unsafe" + + "github.com/taosdata/driver-go/v3/wrapper/cgo" +) + +// typedef void (*__taos_async_whitelist_fn_t)(void *param, int code, TAOS *taos, int numOfWhiteLists, uint64_t* pWhiteLists); + +// TaosFetchWhitelistA DLL_EXPORT void taos_fetch_whitelist_a(TAOS *taos, __taos_async_whitelist_fn_t fp, void *param); +func TaosFetchWhitelistA(taosConnect unsafe.Pointer, caller cgo.Handle) { + C.taos_fetch_whitelist_a_wrapper(taosConnect, caller.Pointer()) +} diff --git a/wrapper/whitelist_test.go b/wrapper/whitelist_test.go new file mode 100644 index 0000000..26a3a6e --- /dev/null +++ b/wrapper/whitelist_test.go @@ -0,0 +1,21 @@ +package wrapper + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/taosdata/driver-go/v3/wrapper/cgo" +) + +func TestGetWhiteList(t *testing.T) { + conn, err := TaosConnect("", "root", "taosdata", "", 0) + assert.NoError(t, err) + defer TaosClose(conn) + c := make(chan *WhitelistResult, 1) + handler := cgo.NewHandle(c) + TaosFetchWhitelistA(conn, handler) + data := <-c + assert.Equal(t, int32(0), data.ErrCode) + assert.Equal(t, 1, len(data.IPNets)) + assert.Equal(t, "0.0.0.0/0", data.IPNets[0].String()) +} diff --git a/wrapper/whitelistcb.go b/wrapper/whitelistcb.go new file mode 100644 index 0000000..34710c8 --- /dev/null +++ b/wrapper/whitelistcb.go @@ -0,0 +1,35 @@ +package wrapper + +import "C" +import ( + "net" + "unsafe" + + "github.com/taosdata/driver-go/v3/wrapper/cgo" +) + +type WhitelistResult struct { + ErrCode int32 + IPNets []*net.IPNet +} + +//export WhitelistCallback +func WhitelistCallback(param unsafe.Pointer, code int, taosConnect unsafe.Pointer, numOfWhiteLists int, pWhiteLists unsafe.Pointer) { + c := (*(*cgo.Handle)(param)).Value().(chan *WhitelistResult) + if code != 0 { + c <- &WhitelistResult{ErrCode: int32(code)} + return + } + ips := make([]*net.IPNet, 0, numOfWhiteLists) + for i := 0; i < numOfWhiteLists; i++ { + ipNet := make([]byte, 8) + for j := 0; j < 8; j++ { + ipNet[j] = *(*byte)(unsafe.Pointer(uintptr(pWhiteLists) + uintptr(i*8) + uintptr(j))) + } + ip := net.IP{ipNet[0], ipNet[1], ipNet[2], ipNet[3]} + ones := int(ipNet[4]) + ipMask := net.CIDRMask(ones, 32) + ips = append(ips, &net.IPNet{IP: ip, Mask: ipMask}) + } + c <- &WhitelistResult{IPNets: ips} +} diff --git a/ws/client/conn.go b/ws/client/conn.go index 493a7f1..37a5dd2 100644 --- a/ws/client/conn.go +++ b/ws/client/conn.go @@ -87,7 +87,7 @@ func NewClient(conn *websocket.Conn, sendChanLength uint) *Client { } func (c *Client) ReadPump() { - c.conn.SetReadLimit(common.BufferSize4M) + c.conn.SetReadLimit(0) c.conn.SetReadDeadline(time.Now().Add(c.PongWait)) c.conn.SetPongHandler(func(string) error { c.conn.SetReadDeadline(time.Now().Add(c.PongWait)) @@ -105,9 +105,9 @@ func (c *Client) ReadPump() { } switch messageType { case websocket.TextMessage: - c.TextMessageHandler(message) + go c.TextMessageHandler(message) case websocket.BinaryMessage: - c.BinaryMessageHandler(message) + go c.BinaryMessageHandler(message) } } } diff --git a/ws/schemaless/config.go b/ws/schemaless/config.go new file mode 100644 index 0000000..58f65b0 --- /dev/null +++ b/ws/schemaless/config.go @@ -0,0 +1,66 @@ +package schemaless + +import ( + "time" +) + +const ( + connAction = "conn" + insertAction = "insert" +) + +type Config struct { + url string + chanLength uint + user string + password string + db string + readTimeout time.Duration + writeTimeout time.Duration + errorHandler func(error) +} + +func NewConfig(url string, chanLength uint, opts ...func(*Config)) *Config { + c := Config{url: url, chanLength: chanLength} + for _, opt := range opts { + opt(&c) + } + + return &c +} + +func SetUser(user string) func(*Config) { + return func(c *Config) { + c.user = user + } +} + +func SetPassword(password string) func(*Config) { + return func(c *Config) { + c.password = password + } +} + +func SetDb(db string) func(*Config) { + return func(c *Config) { + c.db = db + } +} + +func SetReadTimeout(readTimeout time.Duration) func(*Config) { + return func(c *Config) { + c.readTimeout = readTimeout + } +} + +func SetWriteTimeout(writeTimeout time.Duration) func(*Config) { + return func(c *Config) { + c.writeTimeout = writeTimeout + } +} + +func SetErrorHandler(errorHandler func(error)) func(*Config) { + return func(c *Config) { + c.errorHandler = errorHandler + } +} diff --git a/ws/schemaless/proto.go b/ws/schemaless/proto.go new file mode 100644 index 0000000..d75cb14 --- /dev/null +++ b/ws/schemaless/proto.go @@ -0,0 +1,33 @@ +package schemaless + +type wsConnectReq struct { + ReqID uint64 `json:"req_id"` + User string `json:"user"` + Password string `json:"password"` + DB string `json:"db"` +} + +type wsConnectResp struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` +} + +type schemalessReq struct { + ReqID uint64 `json:"req_id"` + DB string `json:"db"` + Protocol int `json:"protocol"` + Precision string `json:"precision"` + TTL int `json:"ttl"` + Data string `json:"data"` +} + +type schemalessResp struct { + Code int `json:"code"` + Message string `json:"message"` + ReqID uint64 `json:"req_id"` + Action string `json:"action"` + Timing int64 `json:"timing"` +} diff --git a/ws/schemaless/schemaless.go b/ws/schemaless/schemaless.go new file mode 100644 index 0000000..db44e09 --- /dev/null +++ b/ws/schemaless/schemaless.go @@ -0,0 +1,262 @@ +package schemaless + +import ( + "container/list" + "context" + "errors" + "fmt" + "net/url" + "sync" + "time" + + "github.com/gorilla/websocket" + jsoniter "github.com/json-iterator/go" + "github.com/taosdata/driver-go/v3/common" + taosErrors "github.com/taosdata/driver-go/v3/errors" + "github.com/taosdata/driver-go/v3/ws/client" +) + +const ( + InfluxDBLineProtocol = 1 + OpenTSDBTelnetLineProtocol = 2 + OpenTSDBJsonFormatProtocol = 3 +) + +type Schemaless struct { + client *client.Client + sendList *list.List + url string + user string + password string + db string + readTimeout time.Duration + lock sync.Mutex + once sync.Once + closeChan chan struct{} + errorHandler func(error) +} + +func NewSchemaless(config *Config) (*Schemaless, error) { + wsUrl, err := url.Parse(config.url) + if err != nil { + return nil, fmt.Errorf("config url error: %s", err) + } + if wsUrl.Scheme != "ws" && wsUrl.Scheme != "wss" { + return nil, errors.New("config url scheme error") + } + if len(wsUrl.Path) == 0 || wsUrl.Path != "/rest/schemaless" { + wsUrl.Path = "/rest/schemaless" + } + ws, _, err := common.DefaultDialer.Dial(wsUrl.String(), nil) + if err != nil { + return nil, fmt.Errorf("dial ws error: %s", err) + } + + s := Schemaless{ + client: client.NewClient(ws, config.chanLength), + sendList: list.New(), + url: config.url, + user: config.user, + password: config.password, + db: config.db, + closeChan: make(chan struct{}), + errorHandler: config.errorHandler, + } + + if config.readTimeout > 0 { + s.readTimeout = config.readTimeout + } + + if config.writeTimeout > 0 { + s.client.WriteWait = config.writeTimeout + } + s.client.ErrorHandler = s.handleError + s.client.TextMessageHandler = s.handleTextMessage + + go s.client.ReadPump() + go s.client.WritePump() + + if err = s.connect(); err != nil { + return nil, fmt.Errorf("connect ws error: %s", err) + } + + return &s, nil +} + +func (s *Schemaless) Insert(lines string, protocol int, precision string, ttl int, reqID int64) error { + if reqID == 0 { + reqID = common.GetReqID() + } + req := &schemalessReq{ + ReqID: uint64(reqID), + DB: s.db, + Protocol: protocol, + Precision: precision, + TTL: ttl, + Data: lines, + } + + args, err := client.JsonI.Marshal(req) + if err != nil { + return err + } + action := &client.WSAction{Action: insertAction, Args: args} + envelope := s.client.GetEnvelope() + err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) + if err != nil { + s.client.PutEnvelope(envelope) + return err + } + respBytes, err := s.sendText(uint64(reqID), envelope) + if err != nil { + return err + } + var resp schemalessResp + err = client.JsonI.Unmarshal(respBytes, &resp) + if err != nil { + return err + } + if resp.Code != 0 { + return taosErrors.NewError(resp.Code, resp.Message) + } + return nil +} + +func (s *Schemaless) Close() { + s.once.Do(func() { + close(s.closeChan) + if s.client != nil { + s.client.Close() + } + s.client = nil + }) +} + +func (s *Schemaless) connect() error { + reqID := uint64(common.GetReqID()) + req := &wsConnectReq{ + ReqID: reqID, + User: s.user, + Password: s.password, + DB: s.db, + } + args, err := client.JsonI.Marshal(req) + if err != nil { + return err + } + action := &client.WSAction{ + Action: connAction, + Args: args, + } + envelope := s.client.GetEnvelope() + err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) + if err != nil { + s.client.PutEnvelope(envelope) + return err + } + + respBytes, err := s.sendText(reqID, envelope) + if err != nil { + return err + } + var resp wsConnectResp + err = client.JsonI.Unmarshal(respBytes, &resp) + if err != nil { + return err + } + if resp.Code != 0 { + return taosErrors.NewError(resp.Code, resp.Message) + } + return nil +} + +func (s *Schemaless) sendText(reqID uint64, envelope *client.Envelope) ([]byte, error) { + envelope.Type = websocket.TextMessage + return s.send(reqID, envelope) +} + +func (s *Schemaless) send(reqID uint64, envelope *client.Envelope) ([]byte, error) { + channel := &IndexedChan{ + index: reqID, + channel: make(chan []byte, 1), + } + element := s.addMessageOutChan(channel) + s.client.Send(envelope) + ctx, cancel := context.WithTimeout(context.Background(), s.readTimeout) + defer cancel() + select { + case <-s.closeChan: + return nil, errors.New("connection closed") + case resp := <-channel.channel: + return resp, nil + case <-ctx.Done(): + s.lock.Lock() + s.sendList.Remove(element) + s.lock.Unlock() + return nil, fmt.Errorf("message timeout :%s", envelope.Msg.String()) + } +} + +type IndexedChan struct { + index uint64 + channel chan []byte +} + +func (s *Schemaless) addMessageOutChan(outChan *IndexedChan) *list.Element { + s.lock.Lock() + defer s.lock.Unlock() + element := s.sendList.PushBack(outChan) + return element +} + +func (s *Schemaless) handleTextMessage(message []byte) { + iter := client.JsonI.BorrowIterator(message) + var reqID uint64 + iter.ReadObjectCB(func(iter *jsoniter.Iterator, s string) bool { + switch s { + case "req_id": + reqID = iter.ReadUint64() + return false + default: + iter.Skip() + } + return iter.Error == nil + }) + client.JsonI.ReturnIterator(iter) + s.lock.Lock() + defer s.lock.Unlock() + + element := s.findOutChanByID(reqID) + if element != nil { + element.Value.(*IndexedChan).channel <- message + s.sendList.Remove(element) + } +} + +func (s *Schemaless) findOutChanByID(index uint64) *list.Element { + root := s.sendList.Front() + if root == nil { + return nil + } + rootIndex := root.Value.(*IndexedChan).index + if rootIndex == index { + return root + } + item := root.Next() + for { + if item == nil || item == root { + return nil + } + if item.Value.(*IndexedChan).index == index { + return item + } + item = item.Next() + } +} + +func (s *Schemaless) handleError(err error) { + if s.errorHandler != nil { + s.errorHandler(err) + } + s.Close() +} diff --git a/ws/schemaless/schemaless_test.go b/ws/schemaless/schemaless_test.go new file mode 100644 index 0000000..bfa3199 --- /dev/null +++ b/ws/schemaless/schemaless_test.go @@ -0,0 +1,134 @@ +package schemaless + +import ( + "fmt" + "io/ioutil" + "net/http" + "strings" + "testing" + "time" + + jsoniter "github.com/json-iterator/go" + taosErrors "github.com/taosdata/driver-go/v3/errors" + "github.com/taosdata/driver-go/v3/ws/client" +) + +// @author: xftan +// @date: 2023/10/13 11:35 +// @description: test websocket schemaless insert +func TestSchemaless_Insert(t *testing.T) { + cases := []struct { + name string + protocol int + precision string + data string + ttl int + code int + }{ + { + name: "influxdb", + protocol: InfluxDBLineProtocol, + precision: "ms", + data: "measurement,host=host1 field1=2i,field2=2.0 1577837300000\n" + + "measurement,host=host1 field1=2i,field2=2.0 1577837400000\n" + + "measurement,host=host1 field1=2i,field2=2.0 1577837500000\n" + + "measurement,host=host1 field1=2i,field2=2.0 1577837600000", + ttl: 1000, + }, + { + name: "opentsdb_telnet", + protocol: OpenTSDBTelnetLineProtocol, + precision: "ms", + data: "meters.current 1648432611249 10.3 location=California.SanFrancisco group=2\n" + + "meters.current 1648432611250 12.6 location=California.SanFrancisco group=2\n" + + "meters.current 1648432611251 10.8 location=California.LosAngeles group=3\n" + + "meters.current 1648432611252 11.3 location=California.LosAngeles group=3\n", + ttl: 1000, + }, + { + name: "opentsdb_json", + protocol: OpenTSDBJsonFormatProtocol, + precision: "ms", + data: "[{\"metric\": \"meters.voltage\", \"timestamp\": 1648432611249, \"value\": 219, \"tags\": " + + "{\"location\": \"California.LosAngeles\", \"groupid\": 1 } }, {\"metric\": \"meters.voltage\", " + + "\"timestamp\": 1648432611250, \"value\": 221, \"tags\": {\"location\": \"California.LosAngeles\", " + + "\"groupid\": 1 } }]", + ttl: 100, + }, + } + + if err := before(); err != nil { + t.Fatal(err) + } + defer func() { _ = after() }() + + s, err := NewSchemaless(NewConfig("ws://localhost:6041/rest/schemaless", 1, + SetDb("test_schemaless_ws"), + SetReadTimeout(10*time.Second), + SetWriteTimeout(10*time.Second), + SetUser("root"), + SetPassword("taosdata"), + SetErrorHandler(func(err error) { + t.Fatal(err) + }), + )) + if err != nil { + t.Fatal(err) + } + //defer s.Close() + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + if err := s.Insert(c.data, c.protocol, c.precision, c.ttl, 0); err != nil { + t.Fatal(err) + } + }) + } +} + +func doRequest(sql string) error { + req, _ := http.NewRequest(http.MethodPost, "http://127.0.0.1:6041/rest/sql", strings.NewReader(sql)) + req.Header.Set("Authorization", "Taosd /KfeAzX/f9na8qdtNZmtONryp201ma04bEl8LcvLUd7a8qdtNZmtONryp201ma04") + resp, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("http code: %d", resp.StatusCode) + } + data, err := ioutil.ReadAll(resp.Body) + if err != nil { + return err + } + iter := client.JsonI.BorrowIterator(data) + code := int32(0) + desc := "" + iter.ReadObjectCB(func(iter *jsoniter.Iterator, s string) bool { + switch s { + case "code": + code = iter.ReadInt32() + case "desc": + desc = iter.ReadString() + default: + iter.Skip() + } + return iter.Error == nil + }) + client.JsonI.ReturnIterator(iter) + if code != 0 { + return taosErrors.NewError(int(code), desc) + } + return nil +} + +func before() error { + if err := doRequest("drop database if exists test_schemaless_ws"); err != nil { + return err + } + return doRequest("create database if not exists test_schemaless_ws") +} + +func after() error { + return doRequest("drop database test_schemaless_ws") +} diff --git a/ws/stmt/stmt_test.go b/ws/stmt/stmt_test.go index f12d58e..4cc1631 100644 --- a/ws/stmt/stmt_test.go +++ b/ws/stmt/stmt_test.go @@ -147,6 +147,9 @@ func query(payload string) (*common.TDEngineRestfulResp, error) { return marshalBody(resp.Body, 512) } +// @author: xftan +// @date: 2023/10/13 11:35 +// @description: test stmt over websocket func TestStmt(t *testing.T) { err := prepareEnv() if err != nil { diff --git a/ws/tmq/config.go b/ws/tmq/config.go index 1441f5a..88ed25b 100644 --- a/ws/tmq/config.go +++ b/ws/tmq/config.go @@ -132,7 +132,7 @@ func (c *config) setSnapshotEnable(enableSnapshot tmq.ConfigValue) error { func (c *config) setWithTableName(withTableName tmq.ConfigValue) error { var ok bool - c.SnapshotEnable, ok = withTableName.(string) + c.WithTableName, ok = withTableName.(string) if !ok { return fmt.Errorf("msg.with.table.name requires string got %T", withTableName) } diff --git a/ws/tmq/consumer.go b/ws/tmq/consumer.go index 24a5c23..df1889b 100644 --- a/ws/tmq/consumer.go +++ b/ws/tmq/consumer.go @@ -40,6 +40,7 @@ type Consumer struct { withTableName string closeOnce sync.Once closeChan chan struct{} + topics []string } type IndexedChan struct { @@ -281,13 +282,19 @@ func (c *Consumer) findOutChanByID(index uint64) *list.Element { } const ( - TMQSubscribe = "subscribe" - TMQPoll = "poll" - TMQFetch = "fetch" - TMQFetchBlock = "fetch_block" - TMQFetchJsonMeta = "fetch_json_meta" - TMQCommit = "commit" - TMQUnsubscribe = "unsubscribe" + TMQSubscribe = "subscribe" + TMQPoll = "poll" + TMQFetch = "fetch" + TMQFetchBlock = "fetch_block" + TMQFetchJsonMeta = "fetch_json_meta" + TMQCommit = "commit" + TMQUnsubscribe = "unsubscribe" + TMQGetTopicAssignment = "assignment" + TMQSeek = "seek" + TMQCommitOffset = "commit_offset" + TMQCommitted = "committed" + TMQPosition = "position" + TMQListTopics = "list_topics" ) var ClosedErr = errors.New("connection closed") @@ -369,6 +376,8 @@ func (c *Consumer) SubscribeTopics(topics []string, rebalanceCb RebalanceCb) err if resp.Code != 0 { return taosErrors.NewError(resp.Code, resp.Message) } + c.topics = make([]string, len(topics)) + copy(c.topics, topics) return nil } @@ -415,26 +424,41 @@ func (c *Consumer) Poll(timeoutMs int) tmq.Event { result := &tmq.DataMessage{} result.SetDbName(resp.Database) result.SetTopic(resp.Topic) + result.SetOffset(tmq.Offset(resp.Offset)) data, err := c.fetch(resp.MessageID) if err != nil { return tmq.NewTMQErrorWithErr(err) } result.SetData(data) + topic := resp.Topic + result.TopicPartition = tmq.TopicPartition{ + Topic: &topic, + Partition: resp.VgroupID, + Offset: tmq.Offset(resp.Offset), + } return result case common.TMQ_RES_TABLE_META: result := &tmq.MetaMessage{} result.SetDbName(resp.Database) result.SetTopic(resp.Topic) + result.SetOffset(tmq.Offset(resp.Offset)) meta, err := c.fetchJsonMeta(resp.MessageID) if err != nil { return tmq.NewTMQErrorWithErr(err) } + topic := resp.Topic + result.TopicPartition = tmq.TopicPartition{ + Topic: &topic, + Partition: resp.VgroupID, + Offset: tmq.Offset(resp.Offset), + } result.SetMeta(meta) return result case common.TMQ_RES_METADATA: result := &tmq.MetaDataMessage{} result.SetDbName(resp.Database) result.SetTopic(resp.Topic) + result.SetOffset(tmq.Offset(resp.Offset)) meta, err := c.fetchJsonMeta(resp.MessageID) if err != nil { return tmq.NewTMQErrorWithErr(err) @@ -447,6 +471,12 @@ func (c *Consumer) Poll(timeoutMs int) tmq.Event { Meta: meta, Data: data, }) + topic := resp.Topic + result.TopicPartition = tmq.TopicPartition{ + Topic: &topic, + Partition: resp.VgroupID, + Offset: tmq.Offset(resp.Offset), + } return result default: return tmq.NewTMQErrorWithErr(err) @@ -558,7 +588,8 @@ func (c *Consumer) fetch(messageID uint64) ([]*tmq.Data, error) { return nil, err } block := respBytes[24:] - data := parser.ReadBlock(unsafe.Pointer(*(*uintptr)(unsafe.Pointer(&block))), resp.Rows, resp.FieldsTypes, resp.Precision) + p := unsafe.Pointer(&block[0]) + data := parser.ReadBlock(p, resp.Rows, resp.FieldsTypes, resp.Precision) tmqData = append(tmqData, &tmq.Data{ TableName: resp.TableName, Data: data, @@ -607,7 +638,11 @@ func (c *Consumer) doCommit(messageID uint64) ([]tmq.TopicPartition, error) { if resp.Code != 0 { return nil, taosErrors.NewError(resp.Code, resp.Message) } - return nil, nil + partitions, err := c.Assignment() + if err != nil { + return nil, err + } + return c.Committed(partitions, 0) } func (c *Consumer) Unsubscribe() error { @@ -646,3 +681,232 @@ func (c *Consumer) Unsubscribe() error { } return nil } + +func (c *Consumer) Assignment() (partitions []tmq.TopicPartition, err error) { + if c.err != nil { + return nil, c.err + } + for _, topic := range c.topics { + reqID := c.generateReqID() + req := &AssignmentReq{ + ReqID: reqID, + Topic: topic, + } + args, err := client.JsonI.Marshal(req) + if err != nil { + return nil, err + } + action := &client.WSAction{ + Action: TMQGetTopicAssignment, + Args: args, + } + envelope := c.client.GetEnvelope() + err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) + if err != nil { + c.client.PutEnvelope(envelope) + return nil, err + } + respBytes, err := c.sendText(reqID, envelope) + if err != nil { + return nil, err + } + var resp AssignmentResp + err = client.JsonI.Unmarshal(respBytes, &resp) + if err != nil { + return nil, err + } + if resp.Code != 0 { + return nil, taosErrors.NewError(resp.Code, resp.Message) + } + topicName := topic + for i := 0; i < len(resp.Assignment); i++ { + offset := tmq.Offset(resp.Assignment[i].Offset) + partitions = append(partitions, tmq.TopicPartition{ + Topic: &topicName, + Partition: resp.Assignment[i].VGroupID, + Offset: offset, + }) + } + } + return partitions, nil +} + +func (c *Consumer) Seek(partition tmq.TopicPartition, ignoredTimeoutMs int) error { + if c.err != nil { + return c.err + } + reqID := c.generateReqID() + req := &OffsetSeekReq{ + ReqID: reqID, + Topic: *partition.Topic, + VgroupID: partition.Partition, + Offset: int64(partition.Offset), + } + args, err := client.JsonI.Marshal(req) + if err != nil { + return err + } + action := &client.WSAction{ + Action: TMQSeek, + Args: args, + } + envelope := c.client.GetEnvelope() + err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) + if err != nil { + c.client.PutEnvelope(envelope) + return err + } + respBytes, err := c.sendText(reqID, envelope) + if err != nil { + return err + } + var resp OffsetSeekResp + err = client.JsonI.Unmarshal(respBytes, &resp) + if err != nil { + return err + } + if resp.Code != 0 { + return taosErrors.NewError(resp.Code, resp.Message) + } + return nil +} + +func (c *Consumer) Committed(partitions []tmq.TopicPartition, timeoutMs int) (offsets []tmq.TopicPartition, err error) { + offsets = make([]tmq.TopicPartition, len(partitions)) + reqID := c.generateReqID() + req := &CommittedReq{ + ReqID: reqID, + TopicVgroupIDs: make([]TopicVgroupID, len(partitions)), + } + for i := 0; i < len(partitions); i++ { + req.TopicVgroupIDs[i] = TopicVgroupID{ + Topic: *partitions[i].Topic, + VgroupID: partitions[i].Partition, + } + } + args, err := client.JsonI.Marshal(req) + if err != nil { + return nil, err + } + action := &client.WSAction{ + Action: TMQCommitted, + Args: args, + } + envelope := c.client.GetEnvelope() + err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) + if err != nil { + c.client.PutEnvelope(envelope) + return nil, err + } + respBytes, err := c.sendText(reqID, envelope) + if err != nil { + return nil, err + } + var resp CommittedResp + err = client.JsonI.Unmarshal(respBytes, &resp) + if err != nil { + return nil, err + } + if resp.Code != 0 { + return nil, taosErrors.NewError(resp.Code, resp.Message) + } + for i := 0; i < len(resp.Committed); i++ { + offsets[i] = tmq.TopicPartition{ + Topic: partitions[i].Topic, + Partition: partitions[i].Partition, + Offset: tmq.Offset(resp.Committed[i]), + } + } + return offsets, nil +} + +func (c *Consumer) CommitOffsets(offsets []tmq.TopicPartition) ([]tmq.TopicPartition, error) { + if c.err != nil { + return nil, c.err + } + for i := 0; i < len(offsets); i++ { + reqID := c.generateReqID() + req := &CommitOffsetReq{ + ReqID: reqID, + Topic: *offsets[i].Topic, + VgroupID: offsets[i].Partition, + Offset: int64(offsets[i].Offset), + } + args, err := client.JsonI.Marshal(req) + if err != nil { + return nil, err + } + action := &client.WSAction{ + Action: TMQCommitOffset, + Args: args, + } + envelope := c.client.GetEnvelope() + err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) + if err != nil { + c.client.PutEnvelope(envelope) + return nil, err + } + respBytes, err := c.sendText(reqID, envelope) + if err != nil { + return nil, err + } + var resp CommitOffsetResp + err = client.JsonI.Unmarshal(respBytes, &resp) + if err != nil { + return nil, err + } + if resp.Code != 0 { + return nil, taosErrors.NewError(resp.Code, resp.Message) + } + } + return c.Committed(offsets, 0) +} + +func (c *Consumer) Position(partitions []tmq.TopicPartition) (offsets []tmq.TopicPartition, err error) { + offsets = make([]tmq.TopicPartition, len(partitions)) + reqID := c.generateReqID() + req := &PositionReq{ + ReqID: reqID, + TopicVgroupIDs: make([]TopicVgroupID, len(partitions)), + } + for i := 0; i < len(partitions); i++ { + req.TopicVgroupIDs[i] = TopicVgroupID{ + Topic: *partitions[i].Topic, + VgroupID: partitions[i].Partition, + } + } + args, err := client.JsonI.Marshal(req) + if err != nil { + return nil, err + } + action := &client.WSAction{ + Action: TMQPosition, + Args: args, + } + envelope := c.client.GetEnvelope() + err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) + if err != nil { + c.client.PutEnvelope(envelope) + return nil, err + } + respBytes, err := c.sendText(reqID, envelope) + if err != nil { + return nil, err + } + var resp PositionResp + err = client.JsonI.Unmarshal(respBytes, &resp) + if err != nil { + return nil, err + } + if resp.Code != 0 { + return nil, taosErrors.NewError(resp.Code, resp.Message) + } + for i := 0; i < len(resp.Position); i++ { + offsets[i] = tmq.TopicPartition{ + Topic: partitions[i].Topic, + Partition: partitions[i].Partition, + Offset: tmq.Offset(resp.Position[i]), + } + } + return offsets, nil +} diff --git a/ws/tmq/consumer_test.go b/ws/tmq/consumer_test.go index 3b948b5..a7bf95e 100644 --- a/ws/tmq/consumer_test.go +++ b/ws/tmq/consumer_test.go @@ -22,7 +22,7 @@ func prepareEnv() error { "drop topic if exists test_ws_tmq_topic", "drop database if exists test_ws_tmq", "create database test_ws_tmq WAL_RETENTION_PERIOD 86400", - "create topic test_ws_tmq_topic with meta as database test_ws_tmq", + "create topic test_ws_tmq_topic as database test_ws_tmq", } for _, step := range steps { err = doRequest(step) @@ -86,6 +86,9 @@ func doRequest(payload string) error { return nil } +// @author: xftan +// @date: 2023/10/13 11:36 +// @description: test tmq subscribe over websocket func TestConsumer(t *testing.T) { err := prepareEnv() if err != nil { @@ -121,19 +124,18 @@ func TestConsumer(t *testing.T) { } }() consumer, err := NewConsumer(&tmq.ConfigMap{ - "ws.url": "ws://127.0.0.1:6041/rest/tmq", - "ws.message.channelLen": uint(0), - "ws.message.timeout": common.DefaultMessageTimeout, - "ws.message.writeWait": common.DefaultWriteWait, - "td.connect.user": "root", - "td.connect.pass": "taosdata", - "group.id": "test", - "client.id": "test_consumer", - "auto.offset.reset": "earliest", - "enable.auto.commit": "true", - "auto.commit.interval.ms": "5000", - "experimental.snapshot.enable": "true", - "msg.with.table.name": "true", + "ws.url": "ws://127.0.0.1:6041/rest/tmq", + "ws.message.channelLen": uint(0), + "ws.message.timeout": common.DefaultMessageTimeout, + "ws.message.writeWait": common.DefaultWriteWait, + "td.connect.user": "root", + "td.connect.pass": "taosdata", + "group.id": "test", + "client.id": "test_consumer", + "auto.offset.reset": "earliest", + "enable.auto.commit": "true", + "auto.commit.interval.ms": "5000", + "msg.with.table.name": "true", }) if err != nil { t.Error(err) @@ -146,10 +148,9 @@ func TestConsumer(t *testing.T) { t.Error(err) return } - gotMeta := false gotData := false for i := 0; i < 5; i++ { - if gotData && gotMeta { + if gotData { return } ev := consumer.Poll(0) @@ -177,84 +178,23 @@ func TestConsumer(t *testing.T) { assert.Equal(t, float64(11.123), v[11].(float64)) assert.Equal(t, "binary", v[12].(string)) assert.Equal(t, "nchar", v[13].(string)) - case *tmq.MetaMessage: - gotMeta = true - meta := e.Value().(*tmq.Meta) - assert.Equal(t, "test_ws_tmq", e.DBName()) - assert.Equal(t, "create", meta.Type) - assert.Equal(t, "t_all", meta.TableName) - assert.Equal(t, "normal", meta.TableType) - assert.Equal(t, []*tmq.Column{ - { - Name: "ts", - Type: 9, - Length: 0, - }, - { - Name: "c1", - Type: 1, - Length: 0, - }, - { - Name: "c2", - Type: 2, - Length: 0, - }, - { - Name: "c3", - Type: 3, - Length: 0, - }, - { - Name: "c4", - Type: 4, - Length: 0, - }, - { - Name: "c5", - Type: 5, - Length: 0, - }, - { - Name: "c6", - Type: 11, - Length: 0, - }, - { - Name: "c7", - Type: 12, - Length: 0, - }, - { - Name: "c8", - Type: 13, - Length: 0, - }, - { - Name: "c9", - Type: 14, - Length: 0, - }, - { - Name: "c10", - Type: 6, - Length: 0, - }, - { - Name: "c11", - Type: 7, - Length: 0, - }, - { - Name: "c12", - Type: 8, - Length: 20, - }, - { - Name: "c13", - Type: 10, - Length: 20, - }}, meta.Columns) + t.Log(e.Offset()) + ass, err := consumer.Assignment() + t.Log(ass) + committed, err := consumer.Committed(ass, 0) + t.Log(committed) + position, _ := consumer.Position(ass) + t.Log(position) + offsets, err := consumer.Position([]tmq.TopicPartition{e.TopicPartition}) + assert.NoError(t, err) + _, err = consumer.CommitOffsets(offsets) + assert.NoError(t, err) + ass, err = consumer.Assignment() + t.Log(ass) + committed, err = consumer.Committed(ass, 0) + t.Log(committed) + position, _ = consumer.Position(ass) + t.Log(position) case tmq.Error: t.Error(e) return @@ -262,7 +202,7 @@ func TestConsumer(t *testing.T) { t.Error("unexpected", e) return } - _, err = consumer.Commit() + } if err != nil { @@ -270,9 +210,6 @@ func TestConsumer(t *testing.T) { return } } - if !gotMeta { - t.Error("no meta got") - } if !gotData { t.Error("no data got") } @@ -282,3 +219,134 @@ func TestConsumer(t *testing.T) { return } } + +func prepareSeekEnv() error { + var err error + steps := []string{ + "drop topic if exists test_ws_tmq_seek_topic", + "drop database if exists test_ws_tmq_seek", + "create database test_ws_tmq_seek vgroups 1 WAL_RETENTION_PERIOD 86400", + "create topic test_ws_tmq_seek_topic as database test_ws_tmq_seek", + "create table test_ws_tmq_seek.t1(ts timestamp,v int)", + "insert into test_ws_tmq_seek.t1 values (now,1)", + } + for _, step := range steps { + err = doRequest(step) + if err != nil { + return err + } + } + return nil +} + +func cleanSeekEnv() error { + var err error + time.Sleep(2 * time.Second) + steps := []string{ + "drop topic if exists test_ws_tmq_seek_topic", + "drop database if exists test_ws_tmq_seek", + } + for _, step := range steps { + err = doRequest(step) + if err != nil { + return err + } + } + return nil +} + +// @author: xftan +// @date: 2023/10/13 11:36 +// @description: test tmq seek over websocket +func TestSeek(t *testing.T) { + err := prepareSeekEnv() + if err != nil { + t.Error(err) + return + } + defer cleanSeekEnv() + consumer, err := NewConsumer(&tmq.ConfigMap{ + "ws.url": "ws://127.0.0.1:6041/rest/tmq", + "ws.message.channelLen": uint(0), + "ws.message.timeout": common.DefaultMessageTimeout, + "ws.message.writeWait": common.DefaultWriteWait, + "td.connect.user": "root", + "td.connect.pass": "taosdata", + "group.id": "test", + "client.id": "test_consumer", + "auto.offset.reset": "earliest", + "enable.auto.commit": "false", + "experimental.snapshot.enable": "false", + "msg.with.table.name": "true", + }) + if err != nil { + t.Error(err) + return + } + defer consumer.Close() + topic := []string{"test_ws_tmq_seek_topic"} + err = consumer.SubscribeTopics(topic, nil) + if err != nil { + t.Error(err) + return + } + partitions, err := consumer.Assignment() + assert.NoError(t, err) + assert.Equal(t, 1, len(partitions)) + assert.Equal(t, "test_ws_tmq_seek_topic", *partitions[0].Topic) + assert.Equal(t, tmq.Offset(0), partitions[0].Offset) + + //poll + messageOffset := tmq.Offset(0) + haveMessage := false + for i := 0; i < 5; i++ { + event := consumer.Poll(500) + if event != nil { + haveMessage = true + _, err = consumer.Commit() + assert.NoError(t, err) + messageOffset = event.(*tmq.DataMessage).Offset() + } + } + assert.True(t, haveMessage) + partitions, err = consumer.Assignment() + assert.NoError(t, err) + assert.Equal(t, 1, len(partitions)) + assert.Equal(t, "test_ws_tmq_seek_topic", *partitions[0].Topic) + assert.GreaterOrEqual(t, partitions[0].Offset, messageOffset) + + //seek + tmpTopic := "test_ws_tmq_seek_topic" + err = consumer.Seek(tmq.TopicPartition{ + Topic: &tmpTopic, + Partition: partitions[0].Partition, + Offset: 0, + }, 0) + assert.NoError(t, err) + + //assignment + partitions, err = consumer.Assignment() + assert.NoError(t, err) + assert.Equal(t, 1, len(partitions)) + assert.Equal(t, "test_ws_tmq_seek_topic", *partitions[0].Topic) + assert.Equal(t, tmq.Offset(0), partitions[0].Offset) + + //poll + messageOffset = tmq.Offset(0) + haveMessage = false + for i := 0; i < 5; i++ { + event := consumer.Poll(500) + if event != nil { + haveMessage = true + messageOffset = event.(*tmq.DataMessage).Offset() + _, err = consumer.Commit() + assert.NoError(t, err) + } + } + partitions, err = consumer.Assignment() + assert.True(t, haveMessage) + assert.NoError(t, err) + assert.Equal(t, 1, len(partitions)) + assert.Equal(t, "test_ws_tmq_seek_topic", *partitions[0].Topic) + assert.GreaterOrEqual(t, partitions[0].Offset, messageOffset) +} diff --git a/ws/tmq/proto.go b/ws/tmq/proto.go index a5376eb..d9b8c1d 100644 --- a/ws/tmq/proto.go +++ b/ws/tmq/proto.go @@ -1,6 +1,10 @@ package tmq -import "encoding/json" +import ( + "encoding/json" + + "github.com/taosdata/driver-go/v3/common/tmq" +) type SubscribeReq struct { ReqID uint64 `json:"req_id"` @@ -42,6 +46,7 @@ type PollResp struct { VgroupID int32 `json:"vgroup_id"` MessageType int32 `json:"message_type"` MessageID uint64 `json:"message_id"` + Offset int64 `json:"offset"` } type FetchJsonMetaReq struct { @@ -111,3 +116,83 @@ type UnsubscribeResp struct { ReqID uint64 `json:"req_id"` Timing int64 `json:"timing"` } + +type AssignmentReq struct { + ReqID uint64 `json:"req_id"` + Topic string `json:"topic"` +} + +type AssignmentResp struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + Assignment []tmq.Assignment `json:"assignment"` +} + +type OffsetSeekReq struct { + ReqID uint64 `json:"req_id"` + Topic string `json:"topic"` + VgroupID int32 `json:"vgroup_id"` + Offset int64 `json:"offset"` +} + +type OffsetSeekResp struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` +} + +type CommittedReq struct { + ReqID uint64 `json:"req_id"` + TopicVgroupIDs []TopicVgroupID `json:"topic_vgroup_ids"` +} + +type CommittedResp struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + Committed []int64 `json:"committed"` +} + +type TopicVgroupID struct { + Topic string `json:"topic"` + VgroupID int32 `json:"vgroup_id"` +} + +type CommitOffsetReq struct { + ReqID uint64 `json:"req_id"` + Topic string `json:"topic"` + VgroupID int32 `json:"vgroup_id"` + Offset int64 `json:"offset"` +} + +type CommitOffsetResp struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + Topic string `json:"topic"` + VgroupID int32 `json:"vgroup_id"` + Offset int64 `json:"offset"` +} + +type PositionReq struct { + ReqID uint64 `json:"req_id"` + TopicVgroupIDs []TopicVgroupID `json:"topic_vgroup_ids"` +} + +type PositionResp struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + Position []int64 `json:"position"` +}