diff --git a/.github/workflows/compatibility.yml b/.github/workflows/compatibility.yml new file mode 100644 index 0000000..956433b --- /dev/null +++ b/.github/workflows/compatibility.yml @@ -0,0 +1,140 @@ +name: compatibility + +on: + pull_request: + branches: + - '3.1' + +jobs: + build: + runs-on: ubuntu-22.04 + strategy: + matrix: + td_version: [ 'main', '3.0' ] + name: Build ${{ matrix.td_version }} + outputs: + commit_id: ${{ steps.get_commit_id.outputs.commit_id }} + steps: + - name: checkout TDengine by pr + if: github.event_name == 'pull_request' + uses: actions/checkout@v3 + with: + repository: 'taosdata/TDengine' + path: 'TDengine' + ref: ${{ matrix.td_version }} + + - name: get_commit_id + id: get_commit_id + run: | + cd TDengine + echo "commit_id=$(git rev-parse HEAD)" >> $GITHUB_OUTPUT + + + - name: Cache server by pr + if: github.event_name == 'pull_request' + id: cache-server-pr + uses: actions/cache@v3 + with: + path: server.tar.gz + key: ${{ runner.os }}-build-${{ matrix.td_version }}-${{ steps.get_commit_id.outputs.commit_id }} + + - name: prepare install + if: > + (github.event_name == 'pull_request' && steps.cache-server-pr.outputs.cache-hit != 'true') + run: sudo apt install -y libgeos-dev + + - name: install TDengine + if: > + (github.event_name == 'pull_request' && steps.cache-server-pr.outputs.cache-hit != 'true') + run: | + cd TDengine + mkdir debug + cd debug + cmake .. -DBUILD_TEST=off -DBUILD_HTTP=false -DVERNUMBER=3.9.9.9 + make -j 4 + + - name: package + if: > + (github.event_name == 'pull_request' && steps.cache-server-pr.outputs.cache-hit != 'true') + run: | + mkdir -p ./release + cp ./TDengine/debug/build/bin/taos ./release/ + cp ./TDengine/debug/build/bin/taosd ./release/ + cp ./TDengine/tools/taosadapter/taosadapter ./release/ + cp ./TDengine/debug/build/lib/libtaos.so.3.9.9.9 ./release/ + cp ./TDengine/debug/build/lib/librocksdb.so.8.1.1 ./release/ ||: + cp ./TDengine/include/client/taos.h ./release/ + cat >./release/install.sh<start.sh<start.sh<start.sh< 64*1024 { + e.Msg = new(bytes.Buffer) + } else { + e.Msg.Reset() + } + if len(e.ErrorChan) > 0 { + e.ErrorChan = make(chan error, 1) + } } +var ClosedError = errors.New("websocket closed") + type Client struct { conn *websocket.Conn status uint32 @@ -63,9 +74,10 @@ type Client struct { TextMessageHandler func(message []byte) BinaryMessageHandler func(message []byte) ErrorHandler func(err error) - SendMessageHandler func(envelope *Envelope) - once sync.Once - errHandlerOnce sync.Once + //SendMessageHandler func(envelope *Envelope) + once sync.Once + errHandlerOnce sync.Once + err error } func NewClient(conn *websocket.Conn, sendChanLength uint) *Client { @@ -80,9 +92,9 @@ func NewClient(conn *websocket.Conn, sendChanLength uint) *Client { TextMessageHandler: func(message []byte) {}, BinaryMessageHandler: func(message []byte) {}, ErrorHandler: func(err error) {}, - SendMessageHandler: func(envelope *Envelope) { - GlobalEnvelopePool.Put(envelope) - }, + //SendMessageHandler: func(envelope *Envelope) { + // GlobalEnvelopePool.Put(envelope) + //}, } } @@ -117,41 +129,61 @@ func (c *Client) WritePump() { defer func() { ticker.Stop() }() + for { select { case message, ok := <-c.sendChan: if !ok { - return + if message == nil { + return + } + message.ErrorChan <- ClosedError + continue } c.conn.SetWriteDeadline(time.Now().Add(c.WriteWait)) err := c.conn.WriteMessage(message.Type, message.Msg.Bytes()) if err != nil { + message.ErrorChan <- err c.handleError(err) - return + c.Close() + for message := range c.sendChan { + if message == nil { + return + } + message.ErrorChan <- ClosedError + } } - c.SendMessageHandler(message) + message.ErrorChan <- nil case <-ticker.C: c.conn.SetWriteDeadline(time.Now().Add(c.WriteWait)) if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil { c.handleError(err) - return + c.Close() + for message := range c.sendChan { + if message == nil { + return + } + message.ErrorChan <- ClosedError + } } } } } -func (c *Client) Send(envelope *Envelope) { +func (c *Client) Send(envelope *Envelope) error { if !c.IsRunning() { - return + return ClosedError } + var err error defer func() { // maybe closed if recover() != nil { - + err = ClosedError return } }() c.sendChan <- envelope + return err } func (c *Client) GetEnvelope() *Envelope { @@ -168,8 +200,8 @@ func (c *Client) IsRunning() bool { func (c *Client) Close() { c.once.Do(func() { - close(c.sendChan) atomic.StoreUint32(&c.status, StatusStop) + close(c.sendChan) if c.conn != nil { c.conn.Close() } diff --git a/ws/schemaless/config.go b/ws/schemaless/config.go index d62eb3b..7599984 100644 --- a/ws/schemaless/config.go +++ b/ws/schemaless/config.go @@ -10,19 +10,22 @@ const ( ) type Config struct { - url string - chanLength uint - user string - password string - db string - readTimeout time.Duration - writeTimeout time.Duration - errorHandler func(error) - enableCompression bool + url string + chanLength uint + user string + password string + db string + readTimeout time.Duration + writeTimeout time.Duration + errorHandler func(error) + enableCompression bool + autoReconnect bool + reconnectIntervalMs int + reconnectRetryCount int } func NewConfig(url string, chanLength uint, opts ...func(*Config)) *Config { - c := Config{url: url, chanLength: chanLength} + c := Config{url: url, chanLength: chanLength, reconnectRetryCount: 3, reconnectIntervalMs: 2000} for _, opt := range opts { opt(&c) } @@ -71,3 +74,21 @@ func SetEnableCompression(enableCompression bool) func(*Config) { c.enableCompression = enableCompression } } + +func SetAutoReconnect(reconnect bool) func(*Config) { + return func(c *Config) { + c.autoReconnect = reconnect + } +} + +func SetReconnectIntervalMs(reconnectIntervalMs int) func(*Config) { + return func(c *Config) { + c.reconnectIntervalMs = reconnectIntervalMs + } +} + +func SetReconnectRetryCount(reconnectRetryCount int) func(*Config) { + return func(c *Config) { + c.reconnectRetryCount = reconnectRetryCount + } +} diff --git a/ws/schemaless/schemaless.go b/ws/schemaless/schemaless.go index 49443a5..f22c7a6 100644 --- a/ws/schemaless/schemaless.go +++ b/ws/schemaless/schemaless.go @@ -5,6 +5,7 @@ import ( "context" "errors" "fmt" + "net" "net/url" "sync" "time" @@ -23,17 +24,23 @@ const ( ) type Schemaless struct { - client *client.Client - sendList *list.List - url string - user string - password string - db string - readTimeout time.Duration - lock sync.Mutex - once sync.Once - closeChan chan struct{} - errorHandler func(error) + client *client.Client + sendList *list.List + url string + user string + password string + db string + readTimeout time.Duration + writeTimeout time.Duration + lock sync.Mutex + once sync.Once + closeChan chan struct{} + errorHandler func(error) + dialer *websocket.Dialer + chanLength uint + autoReconnect bool + reconnectIntervalMs int + reconnectRetryCount int } func NewSchemaless(config *Config) (*Schemaless, error) { @@ -47,21 +54,28 @@ func NewSchemaless(config *Config) (*Schemaless, error) { wsUrl.Path = "/ws" dialer := common.DefaultDialer dialer.EnableCompression = config.enableCompression - ws, _, err := dialer.Dial(wsUrl.String(), nil) + conn, _, err := dialer.Dial(wsUrl.String(), nil) if err != nil { return nil, fmt.Errorf("dial ws error: %s", err) } - ws.EnableWriteCompression(config.enableCompression) - + conn.EnableWriteCompression(config.enableCompression) s := Schemaless{ - client: client.NewClient(ws, config.chanLength), + client: client.NewClient(conn, config.chanLength), sendList: list.New(), - url: config.url, + url: wsUrl.String(), user: config.user, password: config.password, db: config.db, closeChan: make(chan struct{}), errorHandler: config.errorHandler, + dialer: &dialer, + chanLength: config.chanLength, + } + + if config.autoReconnect { + s.autoReconnect = true + s.reconnectIntervalMs = config.reconnectIntervalMs + s.reconnectRetryCount = config.reconnectRetryCount } if config.readTimeout > 0 { @@ -69,21 +83,59 @@ func NewSchemaless(config *Config) (*Schemaless, error) { } if config.writeTimeout > 0 { - s.client.WriteWait = config.writeTimeout + s.writeTimeout = config.writeTimeout } - s.client.ErrorHandler = s.handleError - s.client.TextMessageHandler = s.handleTextMessage - go s.client.ReadPump() - go s.client.WritePump() - - if err = s.connect(); err != nil { + if err = connect(conn, s.user, s.password, s.db, s.writeTimeout, s.readTimeout); err != nil { return nil, fmt.Errorf("connect ws error: %s", err) } + s.initClient(s.client) return &s, nil } +func (s *Schemaless) initClient(c *client.Client) { + if s.writeTimeout > 0 { + c.WriteWait = s.writeTimeout + } + c.ErrorHandler = s.handleError + c.TextMessageHandler = s.handleTextMessage + + go c.ReadPump() + go c.WritePump() +} + +func (s *Schemaless) reconnect() error { + reconnected := false + for i := 0; i < s.reconnectRetryCount; i++ { + time.Sleep(time.Duration(s.reconnectIntervalMs) * time.Millisecond) + conn, _, err := s.dialer.Dial(s.url, nil) + if err != nil { + continue + } + conn.EnableWriteCompression(s.dialer.EnableCompression) + if err = connect(conn, s.user, s.password, s.db, s.writeTimeout, s.readTimeout); err != nil { + conn.Close() + continue + } + if s.client != nil { + s.client.Close() + } + c := client.NewClient(conn, s.chanLength) + s.initClient(c) + s.client = c + reconnected = true + break + } + if !reconnected { + if s.client != nil { + s.client.Close() + } + return errors.New("reconnect failed") + } + return nil +} + func (s *Schemaless) Insert(lines string, protocol int, precision string, ttl int, reqID int64) error { if reqID == 0 { reqID = common.GetReqID() @@ -102,15 +154,30 @@ func (s *Schemaless) Insert(lines string, protocol int, precision string, ttl in return err } action := &client.WSAction{Action: insertAction, Args: args} - envelope := s.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - s.client.PutEnvelope(envelope) return err } respBytes, err := s.sendText(uint64(reqID), envelope) if err != nil { - return err + if !s.autoReconnect { + return err + } + var opError *net.OpError + if errors.Is(err, client.ClosedError) || errors.As(err, &opError) { + err = s.reconnect() + if err != nil { + return err + } + respBytes, err = s.sendText(uint64(reqID), envelope) + if err != nil { + return err + } + } else { + return err + } } var resp schemalessResp err = client.JsonI.Unmarshal(respBytes, &resp) @@ -133,13 +200,16 @@ func (s *Schemaless) Close() { }) } -func (s *Schemaless) connect() error { - reqID := uint64(common.GetReqID()) +var ( + ConnectTimeoutErr = errors.New("schemaless connect timeout") +) + +func connect(ws *websocket.Conn, user string, password string, db string, writeTimeout time.Duration, readTimeout time.Duration) error { req := &wsConnectReq{ - ReqID: reqID, - User: s.user, - Password: s.password, - DB: s.db, + ReqID: 0, + User: user, + Password: password, + DB: db, } args, err := client.JsonI.Marshal(req) if err != nil { @@ -149,14 +219,29 @@ func (s *Schemaless) connect() error { Action: connAction, Args: args, } - envelope := s.client.GetEnvelope() - err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) + connectAction, err := client.JsonI.Marshal(action) if err != nil { - s.client.PutEnvelope(envelope) return err } - - respBytes, err := s.sendText(reqID, envelope) + ws.SetWriteDeadline(time.Now().Add(writeTimeout)) + err = ws.WriteMessage(websocket.TextMessage, connectAction) + if err != nil { + return err + } + done := make(chan struct{}) + ctx, cancel := context.WithTimeout(context.Background(), readTimeout) + var respBytes []byte + go func() { + _, respBytes, err = ws.ReadMessage() + close(done) + }() + select { + case <-done: + cancel() + case <-ctx.Done(): + cancel() + return ConnectTimeoutErr + } if err != nil { return err } @@ -182,7 +267,20 @@ func (s *Schemaless) send(reqID uint64, envelope *client.Envelope) ([]byte, erro channel: make(chan []byte, 1), } element := s.addMessageOutChan(channel) - s.client.Send(envelope) + err := s.client.Send(envelope) + if err != nil { + s.lock.Lock() + s.sendList.Remove(element) + s.lock.Unlock() + return nil, err + } + err = <-envelope.ErrorChan + if err != nil { + s.lock.Lock() + s.sendList.Remove(element) + s.lock.Unlock() + return nil, err + } ctx, cancel := context.WithTimeout(context.Background(), s.readTimeout) defer cancel() select { @@ -259,5 +357,4 @@ func (s *Schemaless) handleError(err error) { if s.errorHandler != nil { s.errorHandler(err) } - s.Close() } diff --git a/ws/schemaless/schemaless_test.go b/ws/schemaless/schemaless_test.go index 83cf47d..dc9caa6 100644 --- a/ws/schemaless/schemaless_test.go +++ b/ws/schemaless/schemaless_test.go @@ -1,14 +1,20 @@ package schemaless import ( + "errors" "fmt" "io/ioutil" "net/http" + "os" + "os/exec" + "runtime" "strings" + "syscall" "testing" "time" jsoniter "github.com/json-iterator/go" + "github.com/stretchr/testify/assert" taosErrors "github.com/taosdata/driver-go/v3/errors" "github.com/taosdata/driver-go/v3/ws/client" ) @@ -133,3 +139,103 @@ func before() error { func after() error { return doRequest("drop database test_schemaless_ws") } + +func newTaosadapter(port string) *exec.Cmd { + command := "taosadapter" + if runtime.GOOS == "windows" { + command = "C:\\TDengine\\taosadapter.exe" + + } + return exec.Command(command, "--port", port, "--logLevel", "debug") +} + +func startTaosadapter(cmd *exec.Cmd, port string) error { + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + err := cmd.Start() + if err != nil { + return err + } + for i := 0; i < 30; i++ { + time.Sleep(time.Millisecond * 100) + resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%s/-/ping", port)) + if err != nil { + continue + } + resp.Body.Close() + time.Sleep(time.Second) + return nil + } + return errors.New("taosadapter start failed") +} + +func stopTaosadapter(cmd *exec.Cmd) { + if cmd.Process == nil { + return + } + cmd.Process.Signal(syscall.SIGINT) + cmd.Process.Wait() + cmd.Process = nil + time.Sleep(time.Second) +} + +func TestSchemalessReconnect(t *testing.T) { + port := "36041" + cmd := newTaosadapter(port) + err := startTaosadapter(cmd, port) + if err != nil { + t.Fatal(err) + } + defer func() { + stopTaosadapter(cmd) + }() + err = doRequest("drop database if exists test_schemaless_reconnect") + if err != nil { + t.Fatal(err) + } + err = doRequest("create database if not exists test_schemaless_reconnect") + if err != nil { + t.Fatal(err) + } + s, err := NewSchemaless(NewConfig(fmt.Sprintf("ws://localhost:%s", port), 1, + SetDb("test_schemaless_reconnect"), + SetReadTimeout(3*time.Second), + SetWriteTimeout(3*time.Second), + SetUser("root"), + SetPassword("taosdata"), + //SetEnableCompression(true), + SetErrorHandler(func(err error) { + t.Log(err) + }), + SetAutoReconnect(true), + SetReconnectIntervalMs(2000), + SetReconnectRetryCount(3), + )) + if err != nil { + t.Fatal(err) + } + stopTaosadapter(cmd) + time.Sleep(time.Second * 3) + startChan := make(chan struct{}) + go func() { + time.Sleep(time.Second * 10) + err = startTaosadapter(cmd, port) + startChan <- struct{}{} + if err != nil { + t.Error(err) + return + } + }() + data := "measurement,host=host1 field1=2i,field2=2.0 1577837300000\n" + + "measurement,host=host1 field1=2i,field2=2.0 1577837400000\n" + + "measurement,host=host1 field1=2i,field2=2.0 1577837500000\n" + + "measurement,host=host1 field1=2i,field2=2.0 1577837600000" + err = s.Insert(data, InfluxDBLineProtocol, "ms", 0, 0) + assert.Error(t, err) + <-startChan + time.Sleep(time.Second) + err = s.Insert(data, InfluxDBLineProtocol, "ms", 0, 0) + assert.NoError(t, err) + err = s.Insert(data, InfluxDBLineProtocol, "ms", 0, 0) + assert.NoError(t, err) +} diff --git a/ws/stmt/config.go b/ws/stmt/config.go index 7eab614..3b533cc 100644 --- a/ws/stmt/config.go +++ b/ws/stmt/config.go @@ -6,22 +6,27 @@ import ( ) type Config struct { - Url string - ChanLength uint - MessageTimeout time.Duration - WriteWait time.Duration - ErrorHandler func(connector *Connector, err error) - CloseHandler func() - User string - Password string - DB string - EnableCompression bool + Url string + ChanLength uint + MessageTimeout time.Duration + WriteWait time.Duration + ErrorHandler func(connector *Connector, err error) + CloseHandler func() + User string + Password string + DB string + EnableCompression bool + AutoReconnect bool + ReconnectIntervalMs int + ReconnectRetryCount int } func NewConfig(url string, chanLength uint) *Config { return &Config{ - Url: url, - ChanLength: chanLength, + Url: url, + ChanLength: chanLength, + ReconnectRetryCount: 3, + ReconnectIntervalMs: 2000, } } func (c *Config) SetConnectUser(user string) error { @@ -65,3 +70,15 @@ func (c *Config) SetCloseHandler(f func()) { func (c *Config) SetEnableCompression(enableCompression bool) { c.EnableCompression = enableCompression } + +func (c *Config) SetAutoReconnect(reconnect bool) { + c.AutoReconnect = reconnect +} + +func (c *Config) SetReconnectIntervalMs(reconnectIntervalMs int) { + c.ReconnectIntervalMs = reconnectIntervalMs +} + +func (c *Config) SetReconnectRetryCount(reconnectRetryCount int) { + c.ReconnectRetryCount = reconnectRetryCount +} diff --git a/ws/stmt/connector.go b/ws/stmt/connector.go index 0cb2d39..a08cd2e 100644 --- a/ws/stmt/connector.go +++ b/ws/stmt/connector.go @@ -6,6 +6,7 @@ import ( "encoding/binary" "errors" "fmt" + "net" "net/url" "sync" "sync/atomic" @@ -19,17 +20,26 @@ import ( ) type Connector struct { - client *client.Client - requestID uint64 - listLock sync.RWMutex - sendChanList *list.List - writeTimeout time.Duration - readTimeout time.Duration - config *Config - closeOnce sync.Once - closeChan chan struct{} - customErrorHandler func(*Connector, error) - customCloseHandler func() + client *client.Client + requestID uint64 + listLock sync.RWMutex + sendChanList *list.List + writeTimeout time.Duration + readTimeout time.Duration + config *Config + closeOnce sync.Once + closeChan chan struct{} + customErrorHandler func(*Connector, error) + customCloseHandler func() + url string + chanLength uint + dialer *websocket.Dialer + autoReconnect bool + reconnectIntervalMs int + reconnectRetryCount int + user string + password string + db string } var ( @@ -66,15 +76,58 @@ func NewConnector(config *Config) (*Connector, error) { if config.MessageTimeout <= 0 { config.MessageTimeout = common.DefaultMessageTimeout } + err = connect(ws, config.User, config.Password, config.DB, writeTimeout, readTimeout) + if err != nil { + return nil, err + } + wsClient := client.NewClient(ws, config.ChanLength) + connector = &Connector{ + client: wsClient, + requestID: 0, + listLock: sync.RWMutex{}, + sendChanList: list.New(), + writeTimeout: writeTimeout, + readTimeout: readTimeout, + config: config, + closeOnce: sync.Once{}, + closeChan: make(chan struct{}), + customErrorHandler: config.ErrorHandler, + customCloseHandler: config.CloseHandler, + url: u.String(), + dialer: &dialer, + chanLength: config.ChanLength, + autoReconnect: config.AutoReconnect, + reconnectIntervalMs: config.ReconnectIntervalMs, + reconnectRetryCount: config.ReconnectRetryCount, + user: config.User, + password: config.Password, + db: config.DB, + } + connector.initClient(connector.client) + return connector, nil +} + +func (c *Connector) initClient(client *client.Client) { + if c.writeTimeout > 0 { + client.WriteWait = c.writeTimeout + } + client.TextMessageHandler = c.handleTextMessage + client.BinaryMessageHandler = c.handleBinaryMessage + client.ErrorHandler = c.handleError + go client.WritePump() + go client.ReadPump() +} + +func connect(ws *websocket.Conn, user string, password string, db string, writeTimeout time.Duration, readTimeout time.Duration) error { req := &ConnectReq{ ReqID: 0, - User: config.User, - Password: config.Password, - DB: config.DB, + User: user, + Password: password, + DB: db, } args, err := client.JsonI.Marshal(req) if err != nil { - return nil, err + return err } action := &client.WSAction{ Action: STMTConnect, @@ -82,12 +135,12 @@ func NewConnector(config *Config) (*Connector, error) { } connectAction, err := client.JsonI.Marshal(action) if err != nil { - return nil, err + return err } ws.SetWriteDeadline(time.Now().Add(writeTimeout)) err = ws.WriteMessage(websocket.TextMessage, connectAction) if err != nil { - return nil, err + return err } done := make(chan struct{}) ctx, cancel := context.WithTimeout(context.Background(), readTimeout) @@ -101,41 +154,20 @@ func NewConnector(config *Config) (*Connector, error) { cancel() case <-ctx.Done(): cancel() - return nil, ConnectTimeoutErr + return ConnectTimeoutErr } if err != nil { - return nil, err + return err } var resp ConnectResp err = client.JsonI.Unmarshal(respBytes, &resp) if err != nil { - return nil, err + return err } if resp.Code != 0 { - return nil, taosErrors.NewError(resp.Code, resp.Message) - } - wsClient := client.NewClient(ws, config.ChanLength) - wsClient.WriteWait = writeTimeout - connector = &Connector{ - client: wsClient, - requestID: 0, - listLock: sync.RWMutex{}, - sendChanList: list.New(), - writeTimeout: writeTimeout, - readTimeout: readTimeout, - config: config, - closeOnce: sync.Once{}, - closeChan: make(chan struct{}), - customErrorHandler: config.ErrorHandler, - customCloseHandler: config.CloseHandler, + return taosErrors.NewError(resp.Code, resp.Message) } - - wsClient.TextMessageHandler = connector.handleTextMessage - wsClient.BinaryMessageHandler = connector.handleBinaryMessage - wsClient.ErrorHandler = connector.handleError - go wsClient.WritePump() - go wsClient.ReadPump() - return connector, nil + return nil } func (c *Connector) handleTextMessage(message []byte) { @@ -191,7 +223,20 @@ func (c *Connector) send(reqID uint64, envelope *client.Envelope) ([]byte, error channel: make(chan []byte, 1), } element := c.addMessageOutChan(channel) - c.client.Send(envelope) + err := c.client.Send(envelope) + if err != nil { + c.listLock.Lock() + c.sendChanList.Remove(element) + c.listLock.Unlock() + return nil, err + } + err = <-envelope.ErrorChan + if err != nil { + c.listLock.Lock() + c.sendChanList.Remove(element) + c.listLock.Unlock() + return nil, err + } ctx, cancel := context.WithTimeout(context.Background(), c.readTimeout) defer cancel() select { @@ -210,6 +255,7 @@ func (c *Connector) send(reqID uint64, envelope *client.Envelope) ([]byte, error func (c *Connector) sendTextWithoutResp(envelope *client.Envelope) { envelope.Type = websocket.TextMessage c.client.Send(envelope) + <-envelope.ErrorChan } func (c *Connector) findOutChanByID(index uint64) *list.Element { @@ -244,13 +290,45 @@ func (c *Connector) handleError(err error) { if c.customErrorHandler != nil { c.customErrorHandler(c, err) } - c.Close() + //c.Close() } func (c *Connector) generateReqID() uint64 { return atomic.AddUint64(&c.requestID, 1) } +func (c *Connector) reconnect() error { + reconnected := false + for i := 0; i < c.reconnectRetryCount; i++ { + time.Sleep(time.Duration(c.reconnectIntervalMs) * time.Millisecond) + conn, _, err := c.dialer.Dial(c.url, nil) + if err != nil { + continue + } + conn.EnableWriteCompression(c.dialer.EnableCompression) + err = connect(conn, c.user, c.password, c.db, c.writeTimeout, c.readTimeout) + if err != nil { + conn.Close() + continue + } + if c.client != nil { + c.client.Close() + } + cl := client.NewClient(conn, c.chanLength) + c.initClient(cl) + c.client = cl + reconnected = true + break + } + if !reconnected { + if c.client != nil { + c.client.Close() + } + return errors.New("reconnect failed") + } + return nil +} + func (c *Connector) Init() (*Stmt, error) { reqID := c.generateReqID() req := &InitReq{ @@ -264,15 +342,30 @@ func (c *Connector) Init() (*Stmt, error) { Action: STMTInit, Args: args, } - envelope := c.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - c.client.PutEnvelope(envelope) return nil, err } respBytes, err := c.sendText(reqID, envelope) if err != nil { - return nil, err + if !c.autoReconnect { + return nil, err + } + var opError *net.OpError + if errors.Is(err, client.ClosedError) || errors.As(err, &opError) { + err = c.reconnect() + if err != nil { + return nil, err + } + respBytes, err = c.sendText(reqID, envelope) + if err != nil { + return nil, err + } + } else { + return nil, err + } } var resp InitResp err = client.JsonI.Unmarshal(respBytes, &resp) diff --git a/ws/stmt/rows.go b/ws/stmt/rows.go index 78f6c75..5247b55 100644 --- a/ws/stmt/rows.go +++ b/ws/stmt/rows.go @@ -23,6 +23,7 @@ type Rows struct { resultID uint64 block []byte conn *Connector + client *client.Client fieldsCount int fieldsNames []string fieldsTypes []uint8 @@ -88,10 +89,10 @@ func (rs *Rows) taosFetchBlock() error { Args: args, } rs.buf.Reset() - envelope := rs.conn.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - rs.conn.client.PutEnvelope(envelope) return err } respBytes, err := rs.conn.sendText(reqID, envelope) @@ -129,10 +130,10 @@ func (rs *Rows) fetchBlock() error { Args: args, } rs.buf.Reset() - envelope := rs.conn.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - rs.conn.client.PutEnvelope(envelope) return err } respBytes, err := rs.conn.sendText(rs.resultID, envelope) @@ -160,10 +161,10 @@ func (rs *Rows) freeResult() error { Args: args, } rs.buf.Reset() - envelope := rs.conn.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - rs.conn.client.PutEnvelope(envelope) return err } rs.conn.sendTextWithoutResp(envelope) diff --git a/ws/stmt/stmt.go b/ws/stmt/stmt.go index 9833647..373b763 100644 --- a/ws/stmt/stmt.go +++ b/ws/stmt/stmt.go @@ -31,10 +31,10 @@ func (s *Stmt) Prepare(sql string) error { Action: STMTPrepare, Args: args, } - envelope := s.connector.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - s.connector.client.PutEnvelope(envelope) return err } respBytes, err := s.connector.sendText(reqID, envelope) @@ -67,10 +67,10 @@ func (s *Stmt) SetTableName(name string) error { Action: STMTSetTableName, Args: args, } - envelope := s.connector.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - s.connector.client.PutEnvelope(envelope) return err } respBytes, err := s.connector.sendText(reqID, envelope) @@ -103,7 +103,8 @@ func (s *Stmt) SetTags(tags *param.Param, bindType *param.ColumnType) error { binary.LittleEndian.PutUint64(reqData, reqID) binary.LittleEndian.PutUint64(reqData[8:], s.id) binary.LittleEndian.PutUint64(reqData[16:], SetTagsMessage) - envelope := s.connector.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) envelope.Msg.Grow(24 + len(block)) envelope.Msg.Write(reqData) envelope.Msg.Write(block) @@ -132,13 +133,13 @@ func (s *Stmt) BindParam(params []*param.Param, bindType *param.ColumnType) erro binary.LittleEndian.PutUint64(reqData, reqID) binary.LittleEndian.PutUint64(reqData[8:], s.id) binary.LittleEndian.PutUint64(reqData[16:], BindMessage) - envelope := s.connector.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) envelope.Msg.Grow(24 + len(block)) envelope.Msg.Write(reqData) envelope.Msg.Write(block) err = client.JsonI.NewEncoder(envelope.Msg).Encode(reqData) if err != nil { - s.connector.client.PutEnvelope(envelope) return err } respBytes, err := s.connector.sendBinary(reqID, envelope) @@ -170,10 +171,10 @@ func (s *Stmt) AddBatch() error { Action: STMTAddBatch, Args: args, } - envelope := s.connector.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - s.connector.client.PutEnvelope(envelope) return err } respBytes, err := s.connector.sendText(reqID, envelope) @@ -205,10 +206,10 @@ func (s *Stmt) Exec() error { Action: STMTExec, Args: args, } - envelope := s.connector.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - s.connector.client.PutEnvelope(envelope) return err } respBytes, err := s.connector.sendText(reqID, envelope) @@ -245,10 +246,10 @@ func (s *Stmt) UseResult() (*Rows, error) { Action: STMTUseResult, Args: args, } - envelope := s.connector.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - s.connector.client.PutEnvelope(envelope) return nil, err } respBytes, err := s.connector.sendText(reqID, envelope) @@ -266,6 +267,7 @@ func (s *Stmt) UseResult() (*Rows, error) { return &Rows{ buf: &bytes.Buffer{}, conn: s.connector, + client: s.connector.client, resultID: resp.ResultID, fieldsCount: resp.FieldsCount, fieldsNames: resp.FieldsNames, @@ -289,10 +291,10 @@ func (s *Stmt) Close() error { Action: STMTClose, Args: args, } - envelope := s.connector.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - s.connector.client.PutEnvelope(envelope) return err } s.connector.sendTextWithoutResp(envelope) diff --git a/ws/stmt/stmt_test.go b/ws/stmt/stmt_test.go index 7cd9496..652766e 100644 --- a/ws/stmt/stmt_test.go +++ b/ws/stmt/stmt_test.go @@ -2,11 +2,16 @@ package stmt import ( "database/sql/driver" + "errors" "fmt" "io" "io/ioutil" "net/http" + "os" + "os/exec" + "runtime" "strings" + "syscall" "testing" "time" @@ -1017,3 +1022,94 @@ func TestSTMTQuery(t *testing.T) { assert.Equal(t, "tb2", row3[27]) } } + +func newTaosadapter(port string) *exec.Cmd { + command := "taosadapter" + if runtime.GOOS == "windows" { + command = "C:\\TDengine\\taosadapter.exe" + + } + return exec.Command(command, "--port", port) +} + +func startTaosadapter(cmd *exec.Cmd, port string) error { + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + err := cmd.Start() + if err != nil { + return err + } + for i := 0; i < 10; i++ { + time.Sleep(time.Millisecond * 100) + resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%s/-/ping", port)) + if err != nil { + continue + } + resp.Body.Close() + time.Sleep(time.Second) + return nil + } + return errors.New("taosadapter start failed") +} + +func stopTaosadapter(cmd *exec.Cmd) { + if cmd.Process == nil { + return + } + cmd.Process.Signal(syscall.SIGINT) + cmd.Process.Wait() + cmd.Process = nil +} + +func TestSTMTReconnect(t *testing.T) { + port := "36042" + cmd := newTaosadapter(port) + err := startTaosadapter(cmd, port) + if err != nil { + t.Fatal(err) + } + defer func() { + stopTaosadapter(cmd) + }() + config := NewConfig("ws://127.0.0.1:"+port, 0) + config.SetConnectUser("root") + config.SetConnectPass("taosdata") + config.SetMessageTimeout(3 * time.Second) + config.SetWriteWait(3 * time.Second) + config.SetEnableCompression(true) + config.SetErrorHandler(func(connector *Connector, err error) { + t.Log(err) + }) + config.SetCloseHandler(func() { + t.Log("stmt websocket closed") + }) + config.SetAutoReconnect(true) + config.SetReconnectRetryCount(3) + config.SetReconnectIntervalMs(2000) + connector, err := NewConnector(config) + if err != nil { + t.Error(err) + return + } + stmt, err := connector.Init() + assert.NoError(t, err) + stmt.Close() + stopTaosadapter(cmd) + startChan := make(chan struct{}) + go func() { + time.Sleep(time.Second * 3) + err = startTaosadapter(cmd, port) + startChan <- struct{}{} + if err != nil { + t.Error(err) + return + } + }() + stmt, err = connector.Init() + assert.Error(t, err) + <-startChan + time.Sleep(time.Second) + stmt, err = connector.Init() + assert.NoError(t, err) + stmt.Close() +} diff --git a/ws/tmq/config.go b/ws/tmq/config.go index 99c96a3..e119dcf 100644 --- a/ws/tmq/config.go +++ b/ws/tmq/config.go @@ -20,6 +20,9 @@ type config struct { SnapshotEnable string WithTableName string EnableCompression bool + AutoReconnect bool + ReconnectIntervalMs int + ReconnectRetryCount int } func newConfig(url string, chanLength uint) *config { @@ -84,3 +87,15 @@ func (c *config) setWithTableName(withTableName string) { func (c *config) setEnableCompression(enableCompression bool) { c.EnableCompression = enableCompression } + +func (c *config) setAutoReconnect(autoReconnect bool) { + c.AutoReconnect = autoReconnect +} + +func (c *config) setReconnectIntervalMs(reconnectIntervalMs int) { + c.ReconnectIntervalMs = reconnectIntervalMs +} + +func (c *config) setReconnectRetryCount(reconnectRetryCount int) { + c.ReconnectRetryCount = reconnectRetryCount +} diff --git a/ws/tmq/consumer.go b/ws/tmq/consumer.go index b4f3e8b..64bb1ed 100644 --- a/ws/tmq/consumer.go +++ b/ws/tmq/consumer.go @@ -6,6 +6,7 @@ import ( "encoding/binary" "errors" "fmt" + "net" "net/url" "strconv" "sync" @@ -23,27 +24,33 @@ import ( ) type Consumer struct { - client *client.Client - requestID uint64 - err error - dataParser *parser.TMQRawDataParser - listLock sync.RWMutex - sendChanList *list.List - messageTimeout time.Duration - autoCommit bool - autoCommitInterval time.Duration - nextAutoCommitTime time.Time - url string - user string - password string - groupID string - clientID string - offsetRest string - snapshotEnable string - withTableName string - closeOnce sync.Once - closeChan chan struct{} - topics []string + client *client.Client + requestID uint64 + err error + dataParser *parser.TMQRawDataParser + listLock sync.RWMutex + sendChanList *list.List + messageTimeout time.Duration + autoCommit bool + autoCommitInterval time.Duration + nextAutoCommitTime time.Time + url string + user string + password string + groupID string + clientID string + offsetRest string + snapshotEnable string + withTableName string + closeOnce sync.Once + closeChan chan struct{} + topics []string + autoReconnect bool + reconnectIntervalMs int + reconnectRetryCount int + chanLength uint + writeWait time.Duration + dialer *websocket.Dialer } type IndexedChan struct { @@ -94,34 +101,75 @@ func NewConsumer(conf *tmq.ConfigMap) (*Consumer, error) { wsClient := client.NewClient(ws, config.ChanLength) consumer := &Consumer{ - client: wsClient, - requestID: 0, - sendChanList: list.New(), - messageTimeout: config.MessageTimeout, - url: config.Url, - user: config.User, - password: config.Password, - groupID: config.GroupID, - clientID: config.ClientID, - offsetRest: config.OffsetRest, - autoCommit: autoCommit, - autoCommitInterval: autoCommitInterval, - snapshotEnable: config.SnapshotEnable, - withTableName: config.WithTableName, - closeChan: make(chan struct{}), - dataParser: parser.NewTMQRawDataParser(), - } - if config.WriteWait > 0 { - wsClient.WriteWait = config.WriteWait - } - wsClient.BinaryMessageHandler = consumer.handleBinaryMessage - wsClient.TextMessageHandler = consumer.handleTextMessage - wsClient.ErrorHandler = consumer.handleError - go wsClient.WritePump() - go wsClient.ReadPump() + client: wsClient, + requestID: 0, + sendChanList: list.New(), + messageTimeout: config.MessageTimeout, + url: u.String(), + user: config.User, + password: config.Password, + groupID: config.GroupID, + clientID: config.ClientID, + offsetRest: config.OffsetRest, + autoCommit: autoCommit, + autoCommitInterval: autoCommitInterval, + snapshotEnable: config.SnapshotEnable, + withTableName: config.WithTableName, + closeChan: make(chan struct{}), + dataParser: parser.NewTMQRawDataParser(), + autoReconnect: config.AutoReconnect, + reconnectIntervalMs: config.ReconnectIntervalMs, + reconnectRetryCount: config.ReconnectRetryCount, + chanLength: config.ChanLength, + writeWait: config.WriteWait, + dialer: &dialer, + } + consumer.initClient(consumer.client) return consumer, nil } +func (c *Consumer) initClient(client *client.Client) { + if c.writeWait > 0 { + client.WriteWait = c.writeWait + } + client.BinaryMessageHandler = c.handleBinaryMessage + client.TextMessageHandler = c.handleTextMessage + client.ErrorHandler = c.handleError + go client.WritePump() + go client.ReadPump() +} + +func (c *Consumer) reconnect() error { + reconnected := false + for i := 0; i < c.reconnectRetryCount; i++ { + time.Sleep(time.Duration(c.reconnectIntervalMs) * time.Millisecond) + conn, _, err := c.dialer.Dial(c.url, nil) + if err != nil { + continue + } + conn.EnableWriteCompression(c.dialer.EnableCompression) + cl := client.NewClient(conn, c.chanLength) + c.initClient(cl) + if c.client != nil { + c.client.Close() + } + c.client = cl + if len(c.topics) > 0 { + err = c.doSubscribe(c.topics, false) + if err != nil { + c.client.Close() + continue + } + } + reconnected = true + break + } + if !reconnected { + return errors.New("reconnect failed") + } + return nil +} + func configMapToConfig(m *tmq.ConfigMap) (*config, error) { url, err := m.Get("ws.url", "") if err != nil { @@ -183,6 +231,18 @@ func configMapToConfig(m *tmq.ConfigMap) (*config, error) { if err != nil { return nil, err } + autoReconnect, err := m.Get("ws.autoReconnect", false) + if err != nil { + return nil, err + } + reconnectIntervalMs, err := m.Get("ws.reconnectIntervalMs", int(2000)) + if err != nil { + return nil, err + } + reconnectRetryCount, err := m.Get("ws.reconnectRetryCount", int(3)) + if err != nil { + return nil, err + } config := newConfig(url.(string), chanLen.(uint)) err = config.setMessageTimeout(messageTimeout.(time.Duration)) if err != nil { @@ -202,6 +262,9 @@ func configMapToConfig(m *tmq.ConfigMap) (*config, error) { config.setSnapshotEnable(enableSnapshot.(string)) config.setWithTableName(withTableName.(string)) config.setEnableCompression(enableCompression.(bool)) + config.setAutoReconnect(autoReconnect.(bool)) + config.setReconnectIntervalMs(reconnectIntervalMs.(int)) + config.setReconnectRetryCount(reconnectRetryCount.(int)) return config, nil } @@ -240,8 +303,9 @@ func (c *Consumer) handleBinaryMessage(message []byte) { } func (c *Consumer) handleError(err error) { - c.err = &WSError{err: err} - c.Close() + if !c.autoReconnect { + c.err = &WSError{err: err} + } } func (c *Consumer) generateReqID() uint64 { @@ -302,17 +366,26 @@ const ( var ClosedErr = errors.New("connection closed") func (c *Consumer) sendText(reqID uint64, envelope *client.Envelope) ([]byte, error) { - if !c.client.IsRunning() { - c.client.PutEnvelope(envelope) - return nil, ClosedErr - } channel := &IndexedChan{ index: reqID, channel: make(chan []byte, 1), } element := c.addMessageOutChan(channel) envelope.Type = websocket.TextMessage - c.client.Send(envelope) + err := c.client.Send(envelope) + if err != nil { + c.listLock.Lock() + c.sendChanList.Remove(element) + c.listLock.Unlock() + return nil, err + } + err = <-envelope.ErrorChan + if err != nil { + c.listLock.Lock() + c.sendChanList.Remove(element) + c.listLock.Unlock() + return nil, err + } ctx, cancel := context.WithTimeout(context.Background(), c.messageTimeout) defer cancel() select { @@ -335,6 +408,10 @@ func (c *Consumer) Subscribe(topic string, rebalanceCb RebalanceCb) error { } func (c *Consumer) SubscribeTopics(topics []string, rebalanceCb RebalanceCb) error { + return c.doSubscribe(topics, c.autoReconnect) +} + +func (c *Consumer) doSubscribe(topics []string, reconnect bool) error { if c.err != nil { return c.err } @@ -359,15 +436,30 @@ func (c *Consumer) SubscribeTopics(topics []string, rebalanceCb RebalanceCb) err Action: TMQSubscribe, Args: args, } - envelope := c.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - c.client.PutEnvelope(envelope) return err } respBytes, err := c.sendText(reqID, envelope) if err != nil { - return err + if !reconnect { + return err + } + var opError *net.OpError + if errors.Is(err, ClosedErr) || errors.Is(err, client.ClosedError) || errors.As(err, &opError) { + err = c.reconnect() + if err != nil { + return err + } + respBytes, err = c.sendText(reqID, envelope) + if err != nil { + return err + } + } else { + return err + } } var resp SubscribeResp err = client.JsonI.Unmarshal(respBytes, &resp) @@ -410,15 +502,30 @@ func (c *Consumer) Poll(timeoutMs int) tmq.Event { Action: TMQPoll, Args: args, } - envelope := c.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - c.client.PutEnvelope(envelope) return tmq.NewTMQErrorWithErr(err) } respBytes, err := c.sendText(reqID, envelope) if err != nil { - return tmq.NewTMQErrorWithErr(err) + if !c.autoReconnect { + return tmq.NewTMQErrorWithErr(err) + } + var opError *net.OpError + if errors.Is(err, ClosedErr) || errors.Is(err, client.ClosedError) || errors.As(err, &opError) { + err = c.reconnect() + if err != nil { + return tmq.NewTMQErrorWithErr(err) + } + respBytes, err = c.sendText(reqID, envelope) + if err != nil { + return tmq.NewTMQErrorWithErr(err) + } + } else { + return tmq.NewTMQErrorWithErr(err) + } } var resp PollResp err = client.JsonI.Unmarshal(respBytes, &resp) @@ -510,10 +617,10 @@ func (c *Consumer) fetchJsonMeta(messageID uint64) (*tmq.Meta, error) { Action: TMQFetchJsonMeta, Args: args, } - envelope := c.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - c.client.PutEnvelope(envelope) return nil, err } respBytes, err := c.sendText(reqID, envelope) @@ -550,10 +657,10 @@ func (c *Consumer) fetch(messageID uint64) ([]*tmq.Data, error) { Action: TMQFetchRaw, Args: args, } - envelope := c.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - c.client.PutEnvelope(envelope) return nil, err } respBytes, err := c.sendText(reqID, envelope) @@ -603,10 +710,10 @@ func (c *Consumer) doCommit() error { Action: TMQCommit, Args: args, } - envelope := c.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - c.client.PutEnvelope(envelope) return err } respBytes, err := c.sendText(reqID, envelope) @@ -640,10 +747,10 @@ func (c *Consumer) Unsubscribe() error { Action: TMQUnsubscribe, Args: args, } - envelope := c.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - c.client.PutEnvelope(envelope) return err } respBytes, err := c.sendText(reqID, envelope) @@ -679,10 +786,10 @@ func (c *Consumer) Assignment() (partitions []tmq.TopicPartition, err error) { Action: TMQGetTopicAssignment, Args: args, } - envelope := c.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - c.client.PutEnvelope(envelope) return nil, err } respBytes, err := c.sendText(reqID, envelope) @@ -729,10 +836,10 @@ func (c *Consumer) Seek(partition tmq.TopicPartition, ignoredTimeoutMs int) erro Action: TMQSeek, Args: args, } - envelope := c.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - c.client.PutEnvelope(envelope) return err } respBytes, err := c.sendText(reqID, envelope) @@ -771,10 +878,10 @@ func (c *Consumer) Committed(partitions []tmq.TopicPartition, timeoutMs int) (of Action: TMQCommitted, Args: args, } - envelope := c.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - c.client.PutEnvelope(envelope) return nil, err } respBytes, err := c.sendText(reqID, envelope) @@ -803,6 +910,8 @@ func (c *Consumer) CommitOffsets(offsets []tmq.TopicPartition) ([]tmq.TopicParti if c.err != nil { return nil, c.err } + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) for i := 0; i < len(offsets); i++ { reqID := c.generateReqID() req := &CommitOffsetReq{ @@ -819,10 +928,9 @@ func (c *Consumer) CommitOffsets(offsets []tmq.TopicPartition) ([]tmq.TopicParti Action: TMQCommitOffset, Args: args, } - envelope := c.client.GetEnvelope() + envelope.Reset() err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - c.client.PutEnvelope(envelope) return nil, err } respBytes, err := c.sendText(reqID, envelope) @@ -862,10 +970,10 @@ func (c *Consumer) Position(partitions []tmq.TopicPartition) (offsets []tmq.Topi Action: TMQPosition, Args: args, } - envelope := c.client.GetEnvelope() + envelope := client.GlobalEnvelopePool.Get() + defer client.GlobalEnvelopePool.Put(envelope) err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { - c.client.PutEnvelope(envelope) return nil, err } respBytes, err := c.sendText(reqID, envelope) diff --git a/ws/tmq/consumer_test.go b/ws/tmq/consumer_test.go index 6407984..37dd34b 100644 --- a/ws/tmq/consumer_test.go +++ b/ws/tmq/consumer_test.go @@ -1,10 +1,15 @@ package tmq import ( + "errors" "fmt" "io/ioutil" "net/http" + "os" + "os/exec" + "runtime" "strings" + "syscall" "testing" "time" @@ -819,3 +824,154 @@ func TestMeta(t *testing.T) { } } } + +func newTaosadapter(port string) *exec.Cmd { + command := "taosadapter" + if runtime.GOOS == "windows" { + command = "C:\\TDengine\\taosadapter.exe" + + } + return exec.Command(command, "--port", port) +} + +func startTaosadapter(cmd *exec.Cmd, port string) error { + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + err := cmd.Start() + if err != nil { + return err + } + for i := 0; i < 10; i++ { + time.Sleep(time.Millisecond * 100) + resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%s/-/ping", port)) + if err != nil { + continue + } + resp.Body.Close() + time.Sleep(time.Second) + return nil + } + return errors.New("taosadapter start failed") +} + +func stopTaosadapter(cmd *exec.Cmd) { + if cmd.Process == nil { + return + } + cmd.Process.Signal(syscall.SIGINT) + cmd.Process.Wait() + cmd.Process = nil +} + +func prepareSubReconnectEnv() error { + var err error + steps := []string{ + "drop topic if exists test_ws_tmq_sub_reconnect_topic", + "drop database if exists test_ws_tmq_sub_reconnect", + "create database test_ws_tmq_sub_reconnect vgroups 1 WAL_RETENTION_PERIOD 86400", + "create topic test_ws_tmq_sub_reconnect_topic as database test_ws_tmq_sub_reconnect", + } + for _, step := range steps { + err = doRequest(step) + if err != nil { + return err + } + } + return nil +} + +func cleanSubReconnectEnv() error { + var err error + time.Sleep(2 * time.Second) + steps := []string{ + "drop topic if exists test_ws_tmq_sub_reconnect_topic", + "drop database if exists test_ws_tmq_sub_reconnect", + } + for _, step := range steps { + err = doRequest(step) + if err != nil { + return err + } + } + return nil +} + +func TestSubscribeReconnect(t *testing.T) { + port := "36043" + cmd := newTaosadapter(port) + err := startTaosadapter(cmd, port) + assert.NoError(t, err) + defer func() { + stopTaosadapter(cmd) + }() + prepareSubReconnectEnv() + defer cleanSubReconnectEnv() + consumer, err := NewConsumer(&tmq.ConfigMap{ + "ws.url": "ws://127.0.0.1:" + port, + "ws.message.channelLen": uint(0), + "ws.message.timeout": time.Second * 5, + "ws.message.writeWait": common.DefaultWriteWait, + "td.connect.user": "root", + "td.connect.pass": "taosdata", + "group.id": "test", + "client.id": "test_consumer", + "auto.offset.reset": "earliest", + "enable.auto.commit": "true", + "auto.commit.interval.ms": "1000", + "msg.with.table.name": "true", + "ws.autoReconnect": true, + "ws.reconnectIntervalMs": 3000, + "ws.reconnectRetryCount": 3, + }) + assert.NoError(t, err) + stopTaosadapter(cmd) + time.Sleep(time.Second) + startChan := make(chan struct{}) + go func() { + time.Sleep(time.Second * 3) + err = startTaosadapter(cmd, port) + if err != nil { + t.Error(err) + return + } + startChan <- struct{}{} + }() + err = consumer.Subscribe("test_ws_tmq_sub_reconnect_topic", nil) + assert.Error(t, err) + <-startChan + time.Sleep(time.Second) + err = consumer.Subscribe("test_ws_tmq_sub_reconnect_topic", nil) + assert.NoError(t, err) + doRequest("create table test_ws_tmq_sub_reconnect.st(ts timestamp,v int) tags (cn binary(20))") + doRequest("create table test_ws_tmq_sub_reconnect.t1 using test_ws_tmq_sub_reconnect.st tags ('t1')") + doRequest("insert into test_ws_tmq_sub_reconnect.t1 values (now,1)") + stopTaosadapter(cmd) + go func() { + time.Sleep(time.Second * 3) + startTaosadapter(cmd, port) + startChan <- struct{}{} + }() + time.Sleep(time.Second) + event := consumer.Poll(500) + assert.NotNil(t, event) + _, ok := event.(tmq.Error) + assert.True(t, ok) + <-startChan + haveMessage := false + for i := 0; i < 10; i++ { + event := consumer.Poll(500) + if event == nil { + continue + } + switch e := event.(type) { + case *tmq.DataMessage: + t.Log(e) + assert.Equal(t, "test_ws_tmq_sub_reconnect", e.DBName()) + haveMessage = true + break + default: + t.Log(e) + } + } + assert.True(t, haveMessage) +}