Skip to content

Commit

Permalink
修正IO错误处理
Browse files Browse the repository at this point in the history
  • Loading branch information
blusewang committed Dec 14, 2023
1 parent a543ab2 commit 3d6aafd
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 62 deletions.
3 changes: 1 addition & 2 deletions internal/app/rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"encoding/json"
"fmt"
"github.com/blusewang/pg/v2/internal/client/frame"
"io"
"reflect"
"strconv"
"strings"
Expand Down Expand Up @@ -104,7 +103,7 @@ func (r *Rows) Next(dest []driver.Value) error {
} else if r.position < 0 || r.position >= rowsLen {
return fmt.Errorf("pg_rows rows length is %v but position is %v", rowsLen, r.position)
} else if r.position == rowsLen {
return io.EOF
return fmt.Errorf("pg_rows rows length is %v but position is %v", rowsLen, r.position)
}
for i, v := range r.rows[r.position].DataArr {
dest[i] = r.data2Value(v, r.columns.Columns[i])
Expand Down
77 changes: 20 additions & 57 deletions internal/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,8 @@ import (
"fmt"
"github.com/blusewang/pg/v2/internal/client/frame"
"github.com/blusewang/pg/v2/internal/client/scram"
"io"
"net"
"os"
"strconv"
"time"
)

Expand Down Expand Up @@ -65,10 +63,6 @@ type Client struct {
notificationHandler NotificationHandler // Listen 消息
}

func (c *Client) TestConn() ([]byte, error) {
return c.reader.Peek(1024)
}

func (c *Client) Connect(ctx context.Context, dsn DataSourceName) (err error) {
c.ctx = ctx
c.Dsn = dsn
Expand Down Expand Up @@ -252,14 +246,14 @@ func (c *Client) Startup() (err error) {

func (c *Client) QueryNoArgs(query string) (res SimpleQueryResponse, err error) {
if err = c.writer.Send(frame.NewSimpleQuery(query)); err != nil {
return res, c.handleError(err)
return res, c.handleIOError(err)
}

for {
var f *frame.Data
f, err = c.reader.Receive()
if err != nil {
return res, c.handleError(err)
return res, c.handleIOError(err)
}
switch f.Type() {
case frame.TypeDataRow:
Expand All @@ -284,23 +278,23 @@ func (c *Client) QueryNoArgs(query string) (res SimpleQueryResponse, err error)

func (c *Client) Parse(name, query string) (res ParseResponse, err error) {
if err = c.writer.Buff(frame.NewParse(name, query)); err != nil {
return res, c.handleError(err)
return res, c.handleIOError(err)
}
if err = c.writer.Buff(frame.NewDescribe(name)); err != nil {
return res, c.handleError(err)
return res, c.handleIOError(err)
}
if err = c.writer.Buff(frame.NewSync()); err != nil {
return res, c.handleError(err)
return res, c.handleIOError(err)
}
if err = c.writer.Flush(); err != nil {
return res, c.handleError(err)
return res, c.handleIOError(err)
}

for {
var f *frame.Data
f, err = c.reader.Receive()
if err != nil {
return res, c.handleError(err)
return res, c.handleIOError(err)
}

switch f.Type() {
Expand All @@ -324,23 +318,23 @@ func (c *Client) Parse(name, query string) (res ParseResponse, err error) {

func (c *Client) BindExec(name string, args []driver.Value) (res BindExecResponse, err error) {
if err = c.writer.Buff(frame.NewBind(name, args)); err != nil {
return res, c.handleError(err)
return res, c.handleIOError(err)
}
if err = c.writer.Buff(frame.NewExecute()); err != nil {
return res, c.handleError(err)
return res, c.handleIOError(err)
}
if err = c.writer.Buff(frame.NewSync()); err != nil {
return res, c.handleError(err)
return res, c.handleIOError(err)
}
if err = c.writer.Flush(); err != nil {
return res, c.handleError(err)
return res, c.handleIOError(err)
}

for {
var f *frame.Data
f, err = c.reader.Receive()
if err != nil {
return res, c.handleError(err)
return res, c.handleIOError(err)
}
switch f.Type() {
case frame.TypeNoticeResponse:
Expand All @@ -365,19 +359,19 @@ func (c *Client) BindExec(name string, args []driver.Value) (res BindExecRespons

func (c *Client) CloseParse(name string) (err error) {
if err = c.writer.Buff(frame.NewCloseStat(name)); err != nil {
return c.handleError(err)
return c.handleIOError(err)
}
if err = c.writer.Buff(frame.NewSync()); err != nil {
return c.handleError(err)
return c.handleIOError(err)
}
if err = c.writer.Flush(); err != nil {
return c.handleError(err)
return c.handleIOError(err)
}
for {
var d *frame.Data
d, err = c.reader.Receive()
if err != nil {
return c.handleError(err)
return c.handleIOError(err)
}
switch d.Type() {
case frame.TypeReadyForQuery:
Expand Down Expand Up @@ -410,7 +404,6 @@ func (c *Client) GetNotification() (pid uint32, channel, message string, err err

}
}

}

// CancelRequest 建立新连接,使用PID+口令从新连接中发出指令
Expand All @@ -427,28 +420,17 @@ func (c *Client) Terminate() (err error) {
return c.cn.Close()
}

func (c *Client) Close() (err error) {
func (c *Client) CloseConn() (err error) {
return c.cn.Close()
}

func (c *Client) IsInTransaction() bool {
return c.status == frame.TransactionStatusIdleInTransaction || c.status == frame.TransactionStatusInFailedTransaction
}

func (c *Client) handleError(err error) error {
if err == io.EOF {
c.ConnectStatus = ConnectStatusDisconnected
go func() {
recover()
_ = c.cn.Close()
if v, has := c.Dsn.Parameter["reconnect"]; has {
c.ConnectStatus = ConnectStatusConnecting
n, _ := strconv.Atoi(v)
time.Sleep(time.Second * time.Duration(n))
c.reconnect()
}
}()
}
func (c *Client) handleIOError(err error) error {
c.ConnectStatus = ConnectStatusDisconnected
_ = c.cn.Close()
return err
}

Expand All @@ -459,28 +441,9 @@ func (c *Client) handlePgError(d *frame.Data) error {
if e.Error.Fail == "FATAL" || e.Error.Fail == "PANIC" {
// 这两种错误需立即断开连接
go func() {
recover()
_ = c.cn.Close()
c.ConnectStatus = ConnectStatusDisconnected
if v, has := c.Dsn.Parameter["reconnect"]; has {
c.ConnectStatus = ConnectStatusConnecting
n, _ := strconv.Atoi(v)
time.Sleep(time.Second * time.Duration(n))
c.reconnect()
}
}()
}
return e.Error
}

func (c *Client) reconnect() {
if err := c.Connect(c.ctx, c.Dsn); err != nil {
return
}
if err := c.AutoSSL(); err != nil {
return
}
if err := c.Startup(); err != nil {
return
}
}
6 changes: 3 additions & 3 deletions listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@ func NewListener(ctx context.Context, dsnString string) (Listener, error) {
return nil, err
}
if err = c.Connect(ctx, dsn); err != nil {
_ = c.Close()
_ = c.CloseConn()
return nil, err
}
if err = c.AutoSSL(); err != nil {
_ = c.Close()
_ = c.CloseConn()
return nil, err
}
if err = c.Startup(); err != nil {
_ = c.Close()
_ = c.CloseConn()
return nil, err
}
return c, nil
Expand Down

0 comments on commit 3d6aafd

Please sign in to comment.