Skip to content

Commit

Permalink
Merge pull request #265 from yinheli/yinheli/enh-ws-conn-228
Browse files Browse the repository at this point in the history
enh: taosConn add atomic closed flag
  • Loading branch information
huskar-t authored May 20, 2024
2 parents cff3f5e + c7c8d2e commit 2532a8e
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 0 deletions.
17 changes: 17 additions & 0 deletions taosWS/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ type taosConn struct {
writeTimeout time.Duration
cfg *config
endpoint string
closed atomic.Bool // set when conn is closed,
}

func (tc *taosConn) generateReqID() uint64 {
Expand Down Expand Up @@ -100,6 +101,10 @@ func (tc *taosConn) Begin() (driver.Tx, error) {
}

func (tc *taosConn) Close() (err error) {
if tc.closed.Swap(true) {
return nil
}

if tc.client != nil {
err = tc.client.Close()
}
Expand All @@ -110,6 +115,9 @@ func (tc *taosConn) Close() (err error) {
}

func (tc *taosConn) Prepare(query string) (driver.Stmt, error) {
if tc.closed.Load() {
return nil, driver.ErrBadConn
}
stmtID, err := tc.stmtInit()
if err != nil {
return nil, err
Expand Down Expand Up @@ -410,6 +418,9 @@ func (tc *taosConn) ExecContext(ctx context.Context, query string, args []driver
}

func (tc *taosConn) execCtx(_ context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
if tc.closed.Load() {
return nil, driver.ErrBadConn
}
if len(args) != 0 {
if !tc.cfg.interpolateParams {
return nil, driver.ErrSkip
Expand Down Expand Up @@ -463,6 +474,9 @@ func (tc *taosConn) QueryContext(ctx context.Context, query string, args []drive
}

func (tc *taosConn) queryCtx(_ context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
if tc.closed.Load() {
return nil, driver.ErrBadConn
}
if len(args) != 0 {
if !tc.cfg.interpolateParams {
return nil, driver.ErrSkip
Expand Down Expand Up @@ -521,6 +535,9 @@ func (tc *taosConn) queryCtx(_ context.Context, query string, args []driver.Name
}

func (tc *taosConn) Ping(ctx context.Context) (err error) {
if tc.closed.Load() {
return driver.ErrBadConn
}
return nil
}

Expand Down
26 changes: 26 additions & 0 deletions taosWS/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,29 @@ func Test_formatBytes(t *testing.T) {
})
}
}

func TestBadConnection(t *testing.T) {
defer func() {
if r := recover(); r != nil {
// bad connection should not panic
t.Fatalf("panic: %v", r)
}
}()

cfg, err := parseDSN(dataSourceName)
if err != nil {
t.Fatalf("parseDSN error: %v", err)
}
conn, err := newTaosConn(cfg)
if err != nil {
t.Fatalf("newTaosConn error: %v", err)
}

// to test bad connection, we manually close the connection
conn.Close()

_, err = conn.Query("select 1", nil)
if err == nil {
t.Fatalf("query should fail")
}
}
9 changes: 9 additions & 0 deletions taosWS/statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ type Stmt struct {
}

func (stmt *Stmt) Close() error {
if stmt.conn == nil || stmt.conn.closed.Load() {
return driver.ErrBadConn
}
err := stmt.conn.stmtClose(stmt.stmtID)
stmt.buffer.Reset()
stmt.conn = nil
Expand All @@ -42,6 +45,9 @@ func (stmt *Stmt) NumInput() int {
}

func (stmt *Stmt) Exec(args []driver.Value) (driver.Result, error) {
if stmt.conn.closed.Load() {
return nil, driver.ErrBadConn
}
if stmt.conn == nil {
return nil, driver.ErrBadConn
}
Expand All @@ -68,6 +74,9 @@ func (stmt *Stmt) Exec(args []driver.Value) (driver.Result, error) {
}

func (stmt *Stmt) Query(args []driver.Value) (driver.Rows, error) {
if stmt.conn.closed.Load() {
return nil, driver.ErrBadConn
}
if stmt.conn == nil {
return nil, driver.ErrBadConn
}
Expand Down

0 comments on commit 2532a8e

Please sign in to comment.