diff --git a/README.md b/README.md index 1c96d776..9ed9e8ed 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,7 @@ go get -u github.com/henrylee2cn/teleport - Server and client are peer-to-peer, have the same API method - Packet contains both Header and Body two parts - Support for customizing head and body coding types separately, e.g `JSON` `Protobuf` `string` +- Support custom communication protocol - Body supports gzip compression - Header contains the status code and its description text - Support push, pull, reply and other means of communication @@ -67,30 +68,16 @@ go get -u github.com/henrylee2cn/teleport Peer -> Connection -> Socket -> Session -> Context ``` - ### 4.3 Packet -``` -HeaderLength | HeaderCodecId | Header | BodyLength | BodyCodecId | Body -``` - -**Notes:** - -- `HeaderLength`: uint32, 4 bytes, big endian -- `HeaderCodecId`: uint8, 1 byte -- `Header`: header bytes -- `BodyLength`: uint32, 4 bytes, big endian - * may be 0, meaning that the `Body` is empty and does not indicate the `BodyCodecId` - * may be 1, meaning that the `Body` is empty but indicates the `BodyCodecId` -- `BodyCodecId`: uint8, 1 byte -- `Body`: body bytes +The contents of every one packet: ```go type Packet struct { - // HeaderCodec header codec name - HeaderCodec string `json:"header_codec"` - // BodyCodec body codec name - BodyCodec string `json:"body_codec"` + // HeaderCodec header codec string + HeaderCodec string + // BodyCodec body codec string + BodyCodec string // header content Header *Header `json:"header"` // body content @@ -104,7 +91,7 @@ type Packet struct { } ``` -### 4.4 Header +Among the contents of the header: ```go type Header struct { @@ -123,6 +110,60 @@ type Header struct { } ``` +### 4.4 Protocol + +The default socket communication protocol: + +``` +HeaderLength | HeaderCodecId | Header | BodyLength | BodyCodecId | Body +``` + +**Notes:** + +- `HeaderLength`: uint32, 4 bytes, big endian +- `HeaderCodecId`: uint8, 1 byte +- `Header`: header bytes +- `BodyLength`: uint32, 4 bytes, big endian + * may be 0, meaning that the `Body` is empty and does not indicate the `BodyCodecId` + * may be 1, meaning that the `Body` is empty but indicates the `BodyCodecId` +- `BodyCodecId`: uint8, 1 byte +- `Body`: body bytes + +You can customize your own communication protocol by implementing the interface: + +```go +// Protocol socket communication protocol +type Protocol interface { + // WritePacket writes header and body to the connection. + WritePacket( + packet *Packet, + destWriter *utils.BufioWriter, + tmpCodecWriterGetter func(codecName string) (*TmpCodecWriter, error), + isActiveClosed func() bool, + ) error + + // ReadPacket reads header and body from the connection. + ReadPacket( + packet *Packet, + bodyAdapter func() interface{}, + srcReader *utils.BufioReader, + codecReaderGetter func(codecId byte) (*CodecReader, error), + isActiveClosed func() bool, + ) error +} +``` + +Next, you can specify the communication protocol in the following ways: + +```go +func SetDefaultProtocol(socket.Protocol) +func (*Peer) ServeConn(conn net.Conn, protocol ...socket.Protocol) Session +func (*Peer) DialContext(ctx context.Context, addr string, protocol ...socket.Protocol) (Session, error) +func (*Peer) Dial(addr string, protocol ...socket.Protocol) (Session, error) +func (*Peer) Listen(protocol ...socket.Protocol) error +``` + + ## 5. Usage - Create a server or client peer @@ -151,7 +192,7 @@ var peer = tp.NewPeer(cfg) peer.Listen() // It can also be used as a client at the same time -var sess, err = peer.Dial("127.0.0.1:8080", "peerid-client") +var sess, err = peer.Dial("127.0.0.1:8080") if err != nil { tp.Panicf("%v", err) } @@ -367,7 +408,7 @@ func main() { peer.PushRouter.Reg(new(Push)) { - var sess, err = peer.Dial("127.0.0.1:9090", "simple_server:9090") + var sess, err = peer.Dial("127.0.0.1:9090") if err != nil { tp.Panicf("%v", err) } diff --git a/README_ZH.md b/README_ZH.md index 3a1e1a05..afd232d8 100644 --- a/README_ZH.md +++ b/README_ZH.md @@ -38,6 +38,7 @@ go get -u github.com/henrylee2cn/teleport - 服务器和客户端之间对等通信,两者API方法基本一致 - 底层通信数据包包含`Header`和`Body`两部分 - 支持单独定制`Header`和`Body`编码类型,例如`JSON` `Protobuf` `string` +- 支持定制通信协议 - `Body`支持gzip压缩 - `Header`包含状态码及其描述文本 - 支持推、拉、回复等通信模式 @@ -68,30 +69,16 @@ go get -u github.com/henrylee2cn/teleport Peer -> Connection -> Socket -> Session -> Context ``` +### 4.3 包内容 -### 4.3 数据包 - -``` -HeaderLength | HeaderCodecId | Header | BodyLength | BodyCodecId | Body -``` - -**注意:** - -- `HeaderLength`: uint32, 4 bytes, big endian -- `HeaderCodecId`: uint8, 1 byte -- `Header`: header bytes -- `BodyLength`: uint32, 4 bytes, big endian - * 可能为0,表示`Body`为空且不指明`BodyCodecId` - * 可能为1,表示`Body`为空但是指明`BodyCodecId` -- `BodyCodecId`: uint8, 1 byte -- `Body`: body bytes +每个数据包的内容如下: ```go type Packet struct { - // HeaderCodec header codec name - HeaderCodec string `json:"header_codec"` - // BodyCodec body codec name - BodyCodec string `json:"body_codec"` + // HeaderCodec header codec string + HeaderCodec string + // BodyCodec body codec string + BodyCodec string // header content Header *Header `json:"header"` // body content @@ -105,7 +92,7 @@ type Packet struct { } ``` -### 4.4 头信息 +其中头部内容为: ```go type Header struct { @@ -124,6 +111,62 @@ type Header struct { } ``` + +### 4.4 通信协议 + +默认的通信协议: + +``` +HeaderLength | HeaderCodecId | Header | BodyLength | BodyCodecId | Body +``` + +**注意:** + +- `HeaderLength`: uint32, 4 bytes, big endian +- `HeaderCodecId`: uint8, 1 byte +- `Header`: header bytes +- `BodyLength`: uint32, 4 bytes, big endian + * 可能为0,表示`Body`为空且不指明`BodyCodecId` + * 可能为1,表示`Body`为空但是指明`BodyCodecId` +- `BodyCodecId`: uint8, 1 byte +- `Body`: body bytes + + +你可以通过实现接口的方法定制自己的通信协议: + +```go +// Protocol socket communication protocol +type Protocol interface { + // WritePacket writes header and body to the connection. + WritePacket( + packet *Packet, + destWriter *utils.BufioWriter, + tmpCodecWriterGetter func(codecName string) (*TmpCodecWriter, error), + isActiveClosed func() bool, + ) error + + // ReadPacket reads header and body from the connection. + ReadPacket( + packet *Packet, + bodyAdapter func() interface{}, + srcReader *utils.BufioReader, + codecReaderGetter func(codecId byte) (*CodecReader, error), + isActiveClosed func() bool, + ) error +} +``` + +接着,你可以使用以下任意方式指定自己的通信协议: + +```go +func SetDefaultProtocol(socket.Protocol) +func (*Peer) ServeConn(conn net.Conn, protocol ...socket.Protocol) Session +func (*Peer) DialContext(ctx context.Context, addr string, protocol ...socket.Protocol) (Session, error) +func (*Peer) Dial(addr string, protocol ...socket.Protocol) (Session, error) +func (*Peer) Listen(protocol ...socket.Protocol) error +``` + + ## 5. 用法 - 创建一个Peer端点,服务端或客户端 @@ -152,7 +195,7 @@ var peer = tp.NewPeer(cfg) peer.Listen() // It can also be used as a client at the same time -var sess, err = peer.Dial("127.0.0.1:8080", "peerid-client") +var sess, err = peer.Dial("127.0.0.1:8080") if err != nil { tp.Panicf("%v", err) } @@ -368,7 +411,7 @@ func main() { peer.PushRouter.Reg(new(Push)) { - var sess, err = peer.Dial("127.0.0.1:9090", "simple_server:9090") + var sess, err = peer.Dial("127.0.0.1:9090") if err != nil { tp.Panicf("%v", err) } diff --git a/peer.go b/peer.go index cb169e63..025ed72b 100644 --- a/peer.go +++ b/peer.go @@ -24,6 +24,7 @@ import ( "github.com/henrylee2cn/goutil" "github.com/henrylee2cn/goutil/errors" + "github.com/henrylee2cn/teleport/socket" ) // Peer peer which is server or client. @@ -89,14 +90,14 @@ func (p *Peer) GetSession(sessionId string) (Session, bool) { } // ServeConn serves the connection and returns a session. -func (p *Peer) ServeConn(conn net.Conn, id ...string) Session { - var session = newSession(p, conn, id...) +func (p *Peer) ServeConn(conn net.Conn, protocol ...socket.Protocol) Session { + var session = newSession(p, conn, protocol) p.sessHub.Set(session) return session } // Dial connects with the peer of the destination address. -func (p *Peer) Dial(addr string, id ...string) (Session, error) { +func (p *Peer) Dial(addr string, protocol ...socket.Protocol) (Session, error) { var conn, err = net.DialTimeout("tcp", addr, p.defaultDialTimeout) if err != nil { return nil, err @@ -104,7 +105,7 @@ func (p *Peer) Dial(addr string, id ...string) (Session, error) { if p.tlsConfig != nil { conn = tls.Client(conn, p.tlsConfig) } - var sess = newSession(p, conn, id...) + var sess = newSession(p, conn, protocol) if err = p.pluginContainer.PostDial(sess); err != nil { sess.Close() return nil, err @@ -117,7 +118,7 @@ func (p *Peer) Dial(addr string, id ...string) (Session, error) { // DialContext connects with the peer of the destination address, // using the provided context. -func (p *Peer) DialContext(ctx context.Context, addr string, id ...string) (Session, error) { +func (p *Peer) DialContext(ctx context.Context, addr string, protocol ...socket.Protocol) (Session, error) { var d net.Dialer var conn, err = d.DialContext(ctx, "tcp", addr) if err != nil { @@ -126,7 +127,7 @@ func (p *Peer) DialContext(ctx context.Context, addr string, id ...string) (Sess if p.tlsConfig != nil { conn = tls.Client(conn, p.tlsConfig) } - var sess = newSession(p, conn, id...) + var sess = newSession(p, conn, protocol) if err = p.pluginContainer.PostDial(sess); err != nil { sess.Close() return nil, err @@ -141,7 +142,7 @@ func (p *Peer) DialContext(ctx context.Context, addr string, id ...string) (Sess var ErrListenClosed = errors.New("listener is closed") // Listen turns on the listening service. -func (p *Peer) Listen() error { +func (p *Peer) Listen(protocol ...socket.Protocol) error { var ( wg sync.WaitGroup count = len(p.listenAddrs) @@ -151,7 +152,7 @@ func (p *Peer) Listen() error { for _, addr := range p.listenAddrs { go func(addr string) { defer wg.Done() - errCh <- p.listen(addr) + errCh <- p.listen(addr, protocol) }(addr) } wg.Wait() @@ -164,7 +165,7 @@ func (p *Peer) Listen() error { return errs } -func (p *Peer) listen(addr string) error { +func (p *Peer) listen(addr string, protocols []socket.Protocol) error { var lis, err = listen(addr, p.tlsConfig) if err != nil { Fatalf("%v", err) @@ -219,7 +220,7 @@ func (p *Peer) listen(addr string) error { time.Sleep(time.Second) goto TRYGO } - }(newSession(p, rw)) + }(newSession(p, rw, protocols)) } } diff --git a/samples/ab/frame_client_ab.go b/samples/ab/frame_client_ab.go index d08ab5c1..5e85dd5a 100644 --- a/samples/ab/frame_client_ab.go +++ b/samples/ab/frame_client_ab.go @@ -30,7 +30,7 @@ func main() { var peer = tp.NewPeer(cfg) - var sess, err = peer.Dial("127.0.0.1:9090", "simple_server:9090") + var sess, err = peer.Dial("127.0.0.1:9090") if err != nil { tp.Panicf("%v", err) } diff --git a/samples/simple/client.go b/samples/simple/client.go index 83858bf7..b97e84f6 100644 --- a/samples/simple/client.go +++ b/samples/simple/client.go @@ -27,7 +27,7 @@ func main() { peer.PushRouter.Reg(new(Push)) { - var sess, err = peer.Dial("127.0.0.1:9090", "simple_server:9090") + var sess, err = peer.Dial("127.0.0.1:9090") if err != nil { tp.Panicf("%v", err) } diff --git a/session.go b/session.go index b4f6101f..cc22de46 100644 --- a/session.go +++ b/session.go @@ -122,12 +122,12 @@ var ( _ ForeSession = new(session) ) -func newSession(peer *Peer, conn net.Conn, id ...string) *session { +func newSession(peer *Peer, conn net.Conn, protocols []socket.Protocol) *session { var s = &session{ peer: peer, pullRouter: peer.PullRouter, pushRouter: peer.PushRouter, - socket: socket.NewSocket(conn, id...), + socket: socket.NewSocket(conn, protocols...), pullCmdMap: goutil.RwMap(), readTimeout: peer.defaultReadTimeout, writeTimeout: peer.defaultWriteTimeout, diff --git a/socket/README.md b/socket/README.md index ea100e23..6b565576 100644 --- a/socket/README.md +++ b/socket/README.md @@ -11,25 +11,13 @@ A concise, powerful and high-performance TCP connection socket. - Header and Body can use different coding types - Body supports gzip compression - Header contains the status code and its description text +- Support custom communication protocol - Each socket is assigned an id - Provides `Socket` hub, `Socket` pool and `*Packet` stack ## Packet -``` -HeaderLength | HeaderCodecId | Header | BodyLength | BodyCodecId | Body -``` - -**Notes:** - -- `HeaderLength`: uint32, 4 bytes, big endian -- `HeaderCodecId`: uint8, 1 byte -- `Header`: header bytes -- `BodyLength`: uint32, 4 bytes, big endian - * may be 0, meaning that the `Body` is empty and does not indicate the `BodyCodecId` - * may be 1, meaning that the `Body` is empty but indicates the `BodyCodecId` -- `BodyCodecId`: uint8, 1 byte -- `Body`: body bytes +The contents of every one packet: ```go type Packet struct { @@ -50,7 +38,7 @@ type Packet struct { } ``` -## Header +Among the contents of the header: ```go type Header struct { @@ -69,6 +57,57 @@ type Header struct { } ``` +## Protocol + +The default socket communication protocol: + +``` +HeaderLength | HeaderCodecId | Header | BodyLength | BodyCodecId | Body +``` + +**Notes:** + +- `HeaderLength`: uint32, 4 bytes, big endian +- `HeaderCodecId`: uint8, 1 byte +- `Header`: header bytes +- `BodyLength`: uint32, 4 bytes, big endian + * may be 0, meaning that the `Body` is empty and does not indicate the `BodyCodecId` + * may be 1, meaning that the `Body` is empty but indicates the `BodyCodecId` +- `BodyCodecId`: uint8, 1 byte +- `Body`: body bytes + +You can customize your own communication protocol by implementing the interface: + +```go +// Protocol socket communication protocol +type Protocol interface { + // WritePacket writes header and body to the connection. + WritePacket( + packet *Packet, + destWriter *utils.BufioWriter, + tmpCodecWriterGetter func(codecName string) (*TmpCodecWriter, error), + isActiveClosed func() bool, + ) error + + // ReadPacket reads header and body from the connection. + ReadPacket( + packet *Packet, + bodyAdapter func() interface{}, + srcReader *utils.BufioReader, + codecReaderGetter func(codecId byte) (*CodecReader, error), + isActiveClosed func() bool, + ) error +} +``` + +Next, you can specify the communication protocol in the following ways: + +```go +func SetDefaultProtocol(Protocol) +func GetSocket(net.Conn, ...Protocol) Socket +func NewSocket(net.Conn, ...Protocol) Socket +``` + ## Demo ### server.go diff --git a/socket/example/server.go b/socket/example/server.go index e8e04fe9..f87676ae 100644 --- a/socket/example/server.go +++ b/socket/example/server.go @@ -48,8 +48,9 @@ func main() { err = s.WritePacket(packet) if err != nil { log.Printf("[SVR] write response err: %v", err) + } else { + // log.Printf("[SVR] write response: %v", packet) } - // log.Printf("[SVR] write response: %v", packet) socket.PutPacket(packet) } }(socket.GetSocket(conn)) diff --git a/socket/gzip.go b/socket/gzip.go deleted file mode 100644 index 91c7867f..00000000 --- a/socket/gzip.go +++ /dev/null @@ -1,131 +0,0 @@ -// Copyright 2017 HenryLee. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package socket - -import ( - "compress/gzip" - "io" - - "github.com/henrylee2cn/teleport/codec" - "github.com/henrylee2cn/teleport/utils" -) - -type GzipEncoder struct { - id byte - gzipWriterMap map[int]*gzip.Writer - w io.Writer - encMap map[int]codec.Encoder - encMaker func(io.Writer) codec.Encoder -} - -func (s *socket) getTmpGzipEncoder(codecName string) (*GzipEncoder, error) { - g, ok := s.gzipEncodeMap[codecName] - if ok { - return g, nil - } - c, err := codec.GetByName(codecName) - if err != nil { - return nil, err - } - w := s.tmpWriter - enc := c.NewEncoder(w) - g = &GzipEncoder{ - id: c.Id(), - gzipWriterMap: s.gzipWriterMap, - w: w, - encMap: map[int]codec.Encoder{gzip.NoCompression: enc}, - encMaker: c.NewEncoder, - } - s.gzipEncodeMap[codecName] = g - return g, nil -} - -func (g *GzipEncoder) Id() byte { - return g.id -} - -func (g *GzipEncoder) Encode(gzipLevel int, v interface{}) error { - enc, ok := g.encMap[gzipLevel] - if gzipLevel == gzip.NoCompression { - return enc.Encode(v) - } - var gz *gzip.Writer - var err error - if ok { - gz = g.gzipWriterMap[gzipLevel] - gz.Reset(g.w) - - } else { - gz, err = gzip.NewWriterLevel(g.w, gzipLevel) - if err != nil { - return err - } - g.gzipWriterMap[gzipLevel] = gz - enc = g.encMaker(gz) - g.encMap[gzipLevel] = enc - } - - if err = enc.Encode(v); err != nil { - return err - } - return gz.Flush() -} - -type GzipDecoder struct { - name string - gzipReader *gzip.Reader - r utils.Reader - dec codec.Decoder - gzDec codec.Decoder - decMaker func(io.Reader) codec.Decoder -} - -func (s *socket) getGzipDecoder(codecId byte) (*GzipDecoder, error) { - g, ok := s.gzipDecodeMap[codecId] - if ok { - return g, nil - } - c, err := codec.GetById(codecId) - if err != nil { - return nil, err - } - r := s.limitReader - gzipReader := s.gzipReader - g = &GzipDecoder{ - r: r, - gzipReader: gzipReader, - decMaker: c.NewDecoder, - dec: c.NewDecoder(r), - gzDec: c.NewDecoder(gzipReader), - name: c.Name(), - } - s.gzipDecodeMap[codecId] = g - return g, nil -} - -func (g *GzipDecoder) Name() string { - return g.name -} - -func (g *GzipDecoder) Decode(gzipLevel int, v interface{}) error { - if gzipLevel == gzip.NoCompression { - return g.dec.Decode(v) - } - var err error - if err = g.gzipReader.Reset(g.r); err != nil { - return err - } - return g.gzDec.Decode(v) -} diff --git a/socket/packet.go b/socket/packet.go index 869bd5d8..73d51d26 100644 --- a/socket/packet.go +++ b/socket/packet.go @@ -150,7 +150,7 @@ func (p *Packet) ResetBodyGetting(bodyGetting func(*Header) interface{}) { p.bodyGetting = bodyGetting } -func (p *Packet) loadBody() interface{} { +func (p *Packet) bodyAdapter() interface{} { if p.bodyGetting != nil { p.Body = p.bodyGetting(p.Header) } else { diff --git a/socket/protocol.go b/socket/protocol.go new file mode 100644 index 00000000..f9d17f44 --- /dev/null +++ b/socket/protocol.go @@ -0,0 +1,327 @@ +// Socket package provides a concise, powerful and high-performance TCP +// +// Copyright 2017 HenryLee. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package socket + +import ( + "compress/gzip" + "encoding/binary" + "io" + "io/ioutil" + + "github.com/henrylee2cn/teleport/codec" + "github.com/henrylee2cn/teleport/utils" +) + +// Protocol socket communication protocol +type Protocol interface { + // WritePacket writes header and body to the connection. + WritePacket( + packet *Packet, + destWriter *utils.BufioWriter, + tmpCodecWriterGetter func(codecName string) (*TmpCodecWriter, error), + isActiveClosed func() bool, + ) error + + // ReadPacket reads header and body from the connection. + ReadPacket( + packet *Packet, + bodyAdapter func() interface{}, + srcReader *utils.BufioReader, + codecReaderGetter func(codecId byte) (*CodecReader, error), + isActiveClosed func() bool, + ) error +} + +// default socket communication protocol +var ( + ProtoLee Protocol = new(protoLee) + defaultProtocol Protocol = ProtoLee +) + +// GetDefaultProtocol gets the default socket communication protocol +func GetDefaultProtocol() Protocol { + return defaultProtocol +} + +// SetDefaultProtocol sets the default socket communication protocol +func SetDefaultProtocol(protocol Protocol) { + defaultProtocol = protocol +} + +type protoLee struct{} + +// WritePacket writes header and body to the connection. +// WritePacket can be made to time out and return an Error with Timeout() == true +// after a fixed time limit; see SetDeadline and SetWriteDeadline. +// Note: +// For the byte stream type of body, write directly, do not do any processing; +// Must be safe for concurrent use by multiple goroutines. +func (p protoLee) WritePacket( + packet *Packet, + destWriter *utils.BufioWriter, + tmpCodecWriterGetter func(codecName string) (*TmpCodecWriter, error), + isActiveClosed func() bool, +) error { + + // write header + if len(packet.HeaderCodec) == 0 { + packet.HeaderCodec = defaultHeaderCodec.Name() + } + tmpCodecWriter, err := tmpCodecWriterGetter(packet.HeaderCodec) + if err != nil { + return err + } + err = p.writeHeader(destWriter, tmpCodecWriter, packet.Header) + packet.HeaderLength = destWriter.Count() + packet.Length = packet.HeaderLength + packet.BodyLength = 0 + if err != nil { + return err + } + + // write body + defer func() { + packet.Length = destWriter.Count() + packet.BodyLength = packet.Length - packet.HeaderLength + }() + + switch bo := packet.Body.(type) { + case nil: + codecId := GetCodecId(packet.BodyCodec) + if codecId == codec.NilCodecId { + err = binary.Write(destWriter, binary.BigEndian, uint32(0)) + } else { + err = binary.Write(destWriter, binary.BigEndian, uint32(1)) + if err == nil { + err = binary.Write(destWriter, binary.BigEndian, codecId) + } + } + + case []byte: + err = p.writeBytesBody(destWriter, bo) + case *[]byte: + err = p.writeBytesBody(destWriter, *bo) + default: + if len(packet.BodyCodec) == 0 { + packet.BodyCodec = defaultBodyCodec.Name() + } + tmpCodecWriter, err = tmpCodecWriterGetter(packet.BodyCodec) + if err == nil { + err = p.writeBody(destWriter, tmpCodecWriter, int(packet.Header.Gzip), bo) + } + } + if err != nil { + return err + } + return destWriter.Flush() +} + +func (protoLee) writeHeader(destWriter *utils.BufioWriter, tmpCodecWriter *TmpCodecWriter, header *Header) error { + err := binary.Write(tmpCodecWriter, binary.BigEndian, tmpCodecWriter.Id()) + if err != nil { + return err + } + err = tmpCodecWriter.Encode(gzip.NoCompression, header) + if err != nil { + return err + } + headerSize := uint32(tmpCodecWriter.Len()) + err = binary.Write(destWriter, binary.BigEndian, headerSize) + if err != nil { + return err + } + _, err = tmpCodecWriter.WriteTo(destWriter) + return err +} + +func (protoLee) writeBytesBody(destWriter *utils.BufioWriter, body []byte) error { + bodySize := uint32(len(body)) + err := binary.Write(destWriter, binary.BigEndian, bodySize) + if err != nil { + return err + } + _, err = destWriter.Write(body) + return err +} + +func (protoLee) writeBody(destWriter *utils.BufioWriter, tmpCodecWriter *TmpCodecWriter, gzipLevel int, body interface{}) error { + err := binary.Write(tmpCodecWriter, binary.BigEndian, tmpCodecWriter.Id()) + if err != nil { + return err + } + err = tmpCodecWriter.Encode(gzipLevel, body) + if err != nil { + return err + } + // write body to socket buffer + bodySize := uint32(tmpCodecWriter.Len()) + err = binary.Write(destWriter, binary.BigEndian, bodySize) + if err != nil { + return err + } + _, err = tmpCodecWriter.WriteTo(destWriter) + return err +} + +// ReadPacket reads header and body from the connection. +// Note: +// For the byte stream type of body, read directly, do not do any processing; +// Must be safe for concurrent use by multiple goroutines. +func (p protoLee) ReadPacket( + packet *Packet, + bodyAdapter func() interface{}, + srcReader *utils.BufioReader, + codecReaderGetter func(codecId byte) (*CodecReader, error), + isActiveClosed func() bool, +) error { + + var ( + hErr, bErr error + b interface{} + ) + packet.HeaderLength, packet.HeaderCodec, hErr = p.readHeader(srcReader, codecReaderGetter, packet.Header) + if hErr == nil { + b = bodyAdapter() + } else { + if hErr == io.EOF || hErr == io.ErrUnexpectedEOF { + packet.Length = packet.HeaderLength + packet.BodyLength = 0 + packet.BodyCodec = "" + return hErr + } else if isActiveClosed() { + packet.Length = packet.HeaderLength + packet.BodyLength = 0 + packet.BodyCodec = "" + return ErrProactivelyCloseSocket + } + } + + packet.BodyLength, packet.BodyCodec, bErr = p.readBody(srcReader, codecReaderGetter, int(packet.Header.Gzip), b) + packet.Length = packet.HeaderLength + packet.BodyLength + if isActiveClosed() { + return ErrProactivelyCloseSocket + } + return bErr +} + +// readHeader reads header from the connection. +// readHeader can be made to time out and return an Error with Timeout() == true +// after a fixed time limit; see SetDeadline and SetReadDeadline. +// Note: must use only one goroutine call. +func (protoLee) readHeader( + srcReader *utils.BufioReader, + codecReaderGetter func(byte) (*CodecReader, error), + header *Header, +) (int64, string, error) { + + srcReader.ResetCount() + srcReader.ResetLimit(-1) + + var headerSize uint32 + err := binary.Read(srcReader, binary.BigEndian, &headerSize) + if err != nil { + return srcReader.Count(), "", err + } + + srcReader.ResetLimit(int64(headerSize)) + + var codecId = codec.NilCodecId + + err = binary.Read(srcReader, binary.BigEndian, &codecId) + if err != nil { + return srcReader.Count(), GetCodecName(codecId), err + } + + codecReader, err := codecReaderGetter(codecId) + if err != nil { + return srcReader.Count(), GetCodecName(codecId), err + } + + err = codecReader.Decode(gzip.NoCompression, header) + return srcReader.Count(), codecReader.Name(), err +} + +// readBody reads body from the connection. +// readBody can be made to time out and return an Error with Timeout() == true +// after a fixed time limit; see SetDeadline and SetReadDeadline. +// Note: must use only one goroutine call, and it must be called after calling the readHeader(). +func (protoLee) readBody( + srcReader *utils.BufioReader, + codecReaderGetter func(byte) (*CodecReader, error), + gzipLevel int, + body interface{}, +) (int64, string, error) { + + srcReader.ResetCount() + srcReader.ResetLimit(-1) + + var ( + bodySize uint32 + codecId = codec.NilCodecId + ) + + err := binary.Read(srcReader, binary.BigEndian, &bodySize) + if err != nil { + return srcReader.Count(), "", err + } + if bodySize == 0 { + return srcReader.Count(), "", err + } + + srcReader.ResetLimit(int64(bodySize)) + + // read body + switch bo := body.(type) { + case nil: + var codecName string + codecName, err = readAll(srcReader, make([]byte, 1)) + return srcReader.Count(), codecName, err + + case []byte: + var codecName string + codecName, err = readAll(srcReader, bo) + return srcReader.Count(), codecName, err + + case *[]byte: + *bo, err = ioutil.ReadAll(srcReader) + return srcReader.Count(), GetCodecNameFromBytes(*bo), err + + default: + err = binary.Read(srcReader, binary.BigEndian, &codecId) + if bodySize == 1 || err != nil { + return srcReader.Count(), GetCodecName(codecId), err + } + codecReader, err := codecReaderGetter(codecId) + if err != nil { + return srcReader.Count(), GetCodecName(codecId), err + } + err = codecReader.Decode(gzipLevel, body) + return srcReader.Count(), codecReader.Name(), err + } +} + +func readAll(reader io.Reader, p []byte) (string, error) { + perLen := len(p) + _, err := reader.Read(p[perLen:]) + if err == nil { + _, err = io.Copy(ioutil.Discard, reader) + } + if len(p) > perLen { + return GetCodecName(p[perLen]), err + } + return "", err +} diff --git a/socket/rw.go b/socket/rw.go new file mode 100644 index 00000000..496dfb7f --- /dev/null +++ b/socket/rw.go @@ -0,0 +1,135 @@ +// Copyright 2017 HenryLee. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package socket + +import ( + "bytes" + "compress/gzip" + "io" + + "github.com/henrylee2cn/teleport/codec" + "github.com/henrylee2cn/teleport/utils" +) + +// TmpCodecWriter writer with gzip and encoder +type TmpCodecWriter struct { + *bytes.Buffer + id byte + tmpGzipWriterMap map[int]*gzip.Writer + encMap map[int]codec.Encoder + encMaker func(io.Writer) codec.Encoder +} + +// Note: reseting the temporary buffer when return the *TmpCodecWriter +func (s *socket) getTmpCodecWriter(codecName string) (*TmpCodecWriter, error) { + w := s.tmpBufferWriter + w.Reset() + t, ok := s.tmpCodecWriterMap[codecName] + if ok { + return t, nil + } + c, err := codec.GetByName(codecName) + if err != nil { + return nil, err + } + enc := c.NewEncoder(w) + t = &TmpCodecWriter{ + id: c.Id(), + tmpGzipWriterMap: s.tmpGzipWriterMap, + Buffer: w, + encMap: map[int]codec.Encoder{gzip.NoCompression: enc}, + encMaker: c.NewEncoder, + } + s.tmpCodecWriterMap[codecName] = t + return t, nil +} + +// Id returns codec id. +func (t *TmpCodecWriter) Id() byte { + return t.id +} + +// Encode writes data with gzip and encoder. +func (t *TmpCodecWriter) Encode(gzipLevel int, v interface{}) error { + enc, ok := t.encMap[gzipLevel] + if gzipLevel == gzip.NoCompression { + return enc.Encode(v) + } + var gz *gzip.Writer + var err error + if ok { + gz = t.tmpGzipWriterMap[gzipLevel] + gz.Reset(t.Buffer) + + } else { + gz, err = gzip.NewWriterLevel(t.Buffer, gzipLevel) + if err != nil { + return err + } + t.tmpGzipWriterMap[gzipLevel] = gz + enc = t.encMaker(gz) + t.encMap[gzipLevel] = enc + } + + if err = enc.Encode(v); err != nil { + return err + } + return gz.Flush() +} + +type CodecReader struct { + *utils.BufioReader + name string + gzipReader *gzip.Reader + dec codec.Decoder + gzDec codec.Decoder +} + +func (s *socket) getCodecReader(codecId byte) (*CodecReader, error) { + r, ok := s.codecReaderMap[codecId] + if ok { + return r, nil + } + c, err := codec.GetById(codecId) + if err != nil { + return nil, err + } + bufioReader := s.bufioReader + gzipReader := s.gzipReader + r = &CodecReader{ + BufioReader: bufioReader, + gzipReader: gzipReader, + dec: c.NewDecoder(bufioReader), + gzDec: c.NewDecoder(gzipReader), + name: c.Name(), + } + s.codecReaderMap[codecId] = r + return r, nil +} + +func (r *CodecReader) Name() string { + return r.name +} + +func (r *CodecReader) Decode(gzipLevel int, v interface{}) error { + if gzipLevel == gzip.NoCompression { + return r.dec.Decode(v) + } + var err error + if err = r.gzipReader.Reset(r.BufioReader); err != nil { + return err + } + return r.gzDec.Decode(v) +} diff --git a/socket/socket.go b/socket/socket.go index 7668e7e3..618dd8e1 100644 --- a/socket/socket.go +++ b/socket/socket.go @@ -29,9 +29,6 @@ package socket import ( "bytes" "compress/gzip" - "encoding/binary" - "io" - "io/ioutil" "net" "sync" "sync/atomic" @@ -39,8 +36,6 @@ import ( "github.com/henrylee2cn/goutil" "github.com/henrylee2cn/goutil/errors" - - "github.com/henrylee2cn/teleport/codec" "github.com/henrylee2cn/teleport/utils" ) @@ -85,8 +80,6 @@ type ( SetWriteDeadline(t time.Time) error // WritePacket writes header and body to the connection. - // WritePacket can be made to time out and return an Error with Timeout() == true - // after a fixed time limit; see SetDeadline and SetWriteDeadline. // Note: must be safe for concurrent use by multiple goroutines. WritePacket(packet *Packet) error @@ -122,28 +115,30 @@ type ( } socket struct { net.Conn - id string - idMutex sync.RWMutex - bufioWriter *utils.BufioWriter - bufioReader *utils.BufioReader - - tmpWriter *bytes.Buffer - limitReader *utils.LimitedReader - - gzipWriterMap map[int]*gzip.Writer - gzipReader *gzip.Reader - gzipEncodeMap map[string]*GzipEncoder // codecId:GzipEncoder - gzipDecodeMap map[byte]*GzipDecoder // codecId:GzipDecoder - ctxPublic goutil.Map - writeMutex sync.Mutex // exclusive writer lock - readMutex sync.Mutex // exclusive read lock - curState int32 - fromPool bool + id string + idMutex sync.RWMutex + ctxPublic goutil.Map + protocol Protocol + curState int32 + fromPool bool + + // about write + bufioWriter *utils.BufioWriter + tmpCodecWriterMap map[string]*TmpCodecWriter // codecId:TmpCodecWriter + tmpGzipWriterMap map[int]*gzip.Writer + tmpBufferWriter *bytes.Buffer + writeMutex sync.Mutex // exclusive writer lock + + // about read + bufioReader *utils.BufioReader + codecReaderMap map[byte]*CodecReader // codecId:CodecReader + gzipReader *gzip.Reader + readMutex sync.Mutex // exclusive read lock } ) const ( - idle int32 = 0 + normal int32 = 0 activeClose int32 = 1 ) @@ -153,37 +148,39 @@ var _ net.Conn = Socket(nil) var ErrProactivelyCloseSocket = errors.New("socket is closed proactively") // GetSocket gets a Socket from pool, and reset it. -func GetSocket(c net.Conn, id ...string) Socket { +func GetSocket(c net.Conn, protocol ...Protocol) Socket { s := socketPool.Get().(*socket) - s.Reset(c, id...) + s.Reset(c, protocol...) return s } var socketPool = sync.Pool{ New: func() interface{} { - s := NewSocket(nil) + s := newSocket(nil, nil) s.fromPool = true return s }, } // NewSocket wraps a net.Conn as a Socket. -func NewSocket(c net.Conn, id ...string) *socket { +func NewSocket(c net.Conn, protocol ...Protocol) Socket { + return newSocket(c, protocol) +} + +func newSocket(c net.Conn, protocols []Protocol) *socket { bufioWriter := utils.NewBufioWriter(c) bufioReader := utils.NewBufioReader(c) - tmpWriter := bytes.NewBuffer(nil) - limitReader := utils.LimitReader(bufioReader, 0) + tmpBufferWriter := bytes.NewBuffer(nil) var s = &socket{ - id: getId(c, id), - Conn: c, - bufioWriter: bufioWriter, - bufioReader: bufioReader, - tmpWriter: tmpWriter, - limitReader: limitReader, - gzipWriterMap: make(map[int]*gzip.Writer), - gzipReader: new(gzip.Reader), - gzipEncodeMap: make(map[string]*GzipEncoder), - gzipDecodeMap: make(map[byte]*GzipDecoder), + protocol: getProtocol(protocols), + Conn: c, + bufioWriter: bufioWriter, + tmpBufferWriter: tmpBufferWriter, + bufioReader: bufioReader, + tmpGzipWriterMap: make(map[int]*gzip.Writer), + gzipReader: new(gzip.Reader), + tmpCodecWriterMap: make(map[string]*TmpCodecWriter), + codecReaderMap: make(map[byte]*CodecReader), } return s } @@ -197,115 +194,13 @@ func NewSocket(c net.Conn, id ...string) *socket { func (s *socket) WritePacket(packet *Packet) (err error) { s.writeMutex.Lock() defer func() { - if err != nil && atomic.LoadInt32(&s.curState) == activeClose { + if err != nil && s.isActiveClosed() { err = ErrProactivelyCloseSocket } s.writeMutex.Unlock() }() s.bufioWriter.ResetCount() - - if len(packet.HeaderCodec) == 0 { - packet.HeaderCodec = defaultHeaderCodec.Name() - } - - // write header - err = s.writeHeader(packet.HeaderCodec, packet.Header) - packet.HeaderLength = s.bufioWriter.Count() - packet.Length = packet.HeaderLength - packet.BodyLength = 0 - if err != nil { - return err - } - - defer func() { - packet.Length = s.bufioWriter.Count() - packet.BodyLength = packet.Length - packet.HeaderLength - }() - - // write body - switch bo := packet.Body.(type) { - case nil: - codecId := GetCodecId(packet.BodyCodec) - if codecId == codec.NilCodecId { - err = binary.Write(s.bufioWriter, binary.BigEndian, uint32(0)) - } else { - err = binary.Write(s.bufioWriter, binary.BigEndian, uint32(1)) - if err == nil { - err = binary.Write(s.bufioWriter, binary.BigEndian, codecId) - } - } - - case []byte: - err = s.writeBytesBody(bo) - case *[]byte: - err = s.writeBytesBody(*bo) - default: - if len(packet.BodyCodec) == 0 { - packet.BodyCodec = defaultBodyCodec.Name() - } - err = s.writeBody(packet.BodyCodec, int(packet.Header.Gzip), bo) - } - if err != nil { - return err - } - return s.bufioWriter.Flush() -} - -func (s *socket) writeHeader(codecName string, header *Header) error { - s.tmpWriter.Reset() - tmpGzipEncoder, err := s.getTmpGzipEncoder(codecName) - if err != nil { - return err - } - err = binary.Write(s.tmpWriter, binary.BigEndian, tmpGzipEncoder.Id()) - if err != nil { - return err - } - err = tmpGzipEncoder.Encode(gzip.NoCompression, header) - if err != nil { - return err - } - headerSize := uint32(s.tmpWriter.Len()) - err = binary.Write(s.bufioWriter, binary.BigEndian, headerSize) - if err != nil { - return err - } - _, err = s.tmpWriter.WriteTo(s.bufioWriter) - return err -} - -func (s *socket) writeBytesBody(body []byte) error { - bodySize := uint32(len(body)) - err := binary.Write(s.bufioWriter, binary.BigEndian, bodySize) - if err != nil { - return err - } - _, err = s.bufioWriter.Write(body) - return err -} - -func (s *socket) writeBody(codecName string, gzipLevel int, body interface{}) error { - s.tmpWriter.Reset() - tmpGzipEncoder, err := s.getTmpGzipEncoder(codecName) - if err != nil { - return err - } - err = binary.Write(s.tmpWriter, binary.BigEndian, tmpGzipEncoder.Id()) - if err != nil { - return err - } - err = tmpGzipEncoder.Encode(gzipLevel, body) - if err != nil { - return err - } - // write body to socket buffer - bodySize := uint32(s.tmpWriter.Len()) - err = binary.Write(s.bufioWriter, binary.BigEndian, bodySize) - if err != nil { - return err - } - _, err = s.tmpWriter.WriteTo(s.bufioWriter) - return err + return s.protocol.WritePacket(packet, s.bufioWriter, s.getTmpCodecWriter, s.isActiveClosed) } // ReadPacket reads header and body from the connection. @@ -315,129 +210,7 @@ func (s *socket) writeBody(codecName string, gzipLevel int, body interface{}) er func (s *socket) ReadPacket(packet *Packet) error { s.readMutex.Lock() defer s.readMutex.Unlock() - var ( - hErr, bErr error - b interface{} - ) - - packet.HeaderLength, packet.HeaderCodec, hErr = s.readHeader(packet.Header) - if hErr == nil { - b = packet.loadBody() - } else { - if hErr == io.EOF || hErr == io.ErrUnexpectedEOF { - packet.Length = packet.HeaderLength - packet.BodyLength = 0 - packet.BodyCodec = "" - return hErr - } else if atomic.LoadInt32(&s.curState) == activeClose { - packet.Length = packet.HeaderLength - packet.BodyLength = 0 - packet.BodyCodec = "" - return ErrProactivelyCloseSocket - } - } - - packet.BodyLength, packet.BodyCodec, bErr = s.readBody(int(packet.Header.Gzip), b) - packet.Length = packet.HeaderLength + packet.BodyLength - if atomic.LoadInt32(&s.curState) == activeClose { - return ErrProactivelyCloseSocket - } - return bErr -} - -// readHeader reads header from the connection. -// readHeader can be made to time out and return an Error with Timeout() == true -// after a fixed time limit; see SetDeadline and SetReadDeadline. -// Note: must use only one goroutine call. -func (s *socket) readHeader(header *Header) (int64, string, error) { - s.bufioReader.ResetCount() - var ( - headerSize uint32 - codecId = codec.NilCodecId - ) - - err := binary.Read(s.bufioReader, binary.BigEndian, &headerSize) - if err != nil { - return s.bufioReader.Count(), "", err - } - - s.limitReader.ResetLimit(int64(headerSize)) - - err = binary.Read(s.limitReader, binary.BigEndian, &codecId) - - if err != nil { - return s.bufioReader.Count(), GetCodecName(codecId), err - } - - gd, err := s.getGzipDecoder(codecId) - if err != nil { - return s.bufioReader.Count(), GetCodecName(codecId), err - } - err = gd.Decode(gzip.NoCompression, header) - return s.bufioReader.Count(), gd.Name(), err -} - -// readBody reads body from the connection. -// readBody can be made to time out and return an Error with Timeout() == true -// after a fixed time limit; see SetDeadline and SetReadDeadline. -// Note: must use only one goroutine call, and it must be called after calling the readHeader(). -func (s *socket) readBody(gzipLevel int, body interface{}) (int64, string, error) { - s.bufioReader.ResetCount() - var ( - bodySize uint32 - codecId = codec.NilCodecId - ) - - err := binary.Read(s.bufioReader, binary.BigEndian, &bodySize) - if err != nil { - return s.bufioReader.Count(), "", err - } - if bodySize == 0 { - return s.bufioReader.Count(), "", err - } - - s.limitReader.ResetLimit(int64(bodySize)) - - // read body - switch bo := body.(type) { - case nil: - var codecName string - codecName, err = readAll(s.limitReader, make([]byte, 1)) - return s.bufioReader.Count(), codecName, err - - case []byte: - var codecName string - codecName, err = readAll(s.limitReader, bo) - return s.bufioReader.Count(), codecName, err - - case *[]byte: - *bo, err = ioutil.ReadAll(s.limitReader) - return s.bufioReader.Count(), GetCodecNameFromBytes(*bo), err - - default: - err = binary.Read(s.limitReader, binary.BigEndian, &codecId) - if bodySize == 1 || err != nil { - return s.bufioReader.Count(), GetCodecName(codecId), err - } - gd, err := s.getGzipDecoder(codecId) - if err != nil { - return s.bufioReader.Count(), GetCodecName(codecId), err - } - err = gd.Decode(gzipLevel, body) - return s.bufioReader.Count(), gd.Name(), err - } -} - -func readAll(reader io.Reader, p []byte) (string, error) { - perLen := len(p) - _, err := reader.Read(p[perLen:]) - if err == nil { - _, err = io.Copy(ioutil.Discard, reader) - } - if len(p) > perLen { - return GetCodecName(p[perLen]), err - } - return "", err + return s.protocol.ReadPacket(packet, packet.bodyAdapter, s.bufioReader, s.getCodecReader, s.isActiveClosed) } // Public returns temporary public data of Socket. @@ -460,6 +233,9 @@ func (s *socket) PublicLen() int { func (s *socket) Id() string { s.idMutex.RLock() id := s.id + if len(id) == 0 { + id = s.RemoteAddr().String() + } s.idMutex.RUnlock() return id } @@ -472,18 +248,19 @@ func (s *socket) SetId(id string) { } // Reset reset net.Conn -func (s *socket) Reset(netConn net.Conn, id ...string) { +func (s *socket) Reset(netConn net.Conn, protocol ...Protocol) { atomic.StoreInt32(&s.curState, activeClose) if s.Conn != nil { s.Conn.Close() } s.readMutex.Lock() s.writeMutex.Lock() - s.SetId(getId(netConn, id)) + s.SetId("") + s.protocol = getProtocol(protocol) s.Conn = netConn s.bufioReader.Reset(netConn) s.bufioWriter.Reset(netConn) - atomic.StoreInt32(&s.curState, idle) + atomic.StoreInt32(&s.curState, normal) s.readMutex.Unlock() s.writeMutex.Unlock() } @@ -492,7 +269,7 @@ func (s *socket) Reset(netConn net.Conn, id ...string) { // Any blocked Read or Write operations will be unblocked and return errors. // If it is from 'GetSocket()' function(a pool), return itself to pool. func (s *socket) Close() error { - if atomic.LoadInt32(&s.curState) == activeClose { + if s.isActiveClosed() { return nil } atomic.StoreInt32(&s.curState, activeClose) @@ -509,12 +286,12 @@ func (s *socket) Close() error { s.writeMutex.Unlock() }() - if atomic.LoadInt32(&s.curState) == activeClose { + if s.isActiveClosed() { return nil } s.closeGzipReader() - for _, gz := range s.gzipWriterMap { + for _, gz := range s.tmpGzipWriterMap { errs = append(errs, gz.Close()) } if s.fromPool { @@ -527,6 +304,10 @@ func (s *socket) Close() error { return errors.Merge(errs...) } +func (s *socket) isActiveClosed() bool { + return atomic.LoadInt32(&s.curState) == activeClose +} + func (s *socket) closeGzipReader() { defer func() { recover() @@ -534,12 +315,20 @@ func (s *socket) closeGzipReader() { s.gzipReader.Close() } -func getId(c net.Conn, ids []string) string { - var id string - if len(ids) > 0 && len(ids[0]) > 0 { - id = ids[0] - } else if c != nil { - id = c.RemoteAddr().String() +func getProtocol(protocols []Protocol) Protocol { + if len(protocols) > 0 { + return protocols[0] + } else { + return defaultProtocol } - return id } + +// func getId(c net.Conn, ids []string) string { +// var id string +// if len(ids) > 0 && len(ids[0]) > 0 { +// id = ids[0] +// } else if c != nil { +// id = c.RemoteAddr().String() +// } +// return id +// } diff --git a/utils.go b/utils.go index dcfe5838..ee8efb3b 100644 --- a/utils.go +++ b/utils.go @@ -33,6 +33,14 @@ var GetSenderPacket = socket.GetSenderPacket // func GetReceiverPacket(bodyGetting func(*socket.Header) interface{}) *socket.Packet var GetReceiverPacket = socket.GetReceiverPacket +// GetDefaultProtocol gets the default socket communication protocol +// func GetDefaultProtocol() Protocol +var GetDefaultProtocol = socket.GetDefaultProtocol + +// SetDefaultProtocol sets the default socket communication protocol +// func SetDefaultProtocol(protocol Protocol) +var SetDefaultProtocol = socket.SetDefaultProtocol + // PutPacket puts a *socket.Packet to packet stack. // func PutPacket(p *socket.Packet) var PutPacket = socket.PutPacket diff --git a/utils/bufio_reader.go b/utils/bufio_reader.go new file mode 100644 index 00000000..93291e96 --- /dev/null +++ b/utils/bufio_reader.go @@ -0,0 +1,130 @@ +// Copyright 2015-2017 HenryLee. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "bufio" + "compress/flate" + "io" + "math" +) + +var _ flate.Reader = new(BufioReader) + +type BufioReader struct { + reader *bufio.Reader + count int64 + limit int64 // max bytes remaining +} + +func NewBufioReader(r io.Reader, limit ...int64) *BufioReader { + br := &BufioReader{ + reader: bufio.NewReader(r), + } + if len(limit) > 0 && limit[0] >= 0 { + br.limit = limit[0] + } else { + br.limit = math.MaxInt64 + } + return br +} + +func NewBufioReaderSize(r io.Reader, size int, limit ...int64) *BufioReader { + br := &BufioReader{ + reader: bufio.NewReaderSize(r, size), + } + if len(limit) > 0 { + br.limit = limit[0] + } else { + br.limit = math.MaxInt64 + } + return br +} + +func (b *BufioReader) ResetCount() { + b.count = 0 +} + +func (b *BufioReader) ResetLimit(limit int64) { + if limit < 0 { + b.limit = math.MaxInt64 + } else { + b.limit = limit + } +} + +func (b *BufioReader) Count() int64 { + return b.count +} + +func (b *BufioReader) Buffered() int { + return b.reader.Buffered() +} + +func (b *BufioReader) Discard(n int) (discarded int, err error) { + if b.limit <= 0 { + return 0, io.EOF + } + if b.limit < int64(n) { + n = int(b.limit) + } + discarded, err = b.reader.Discard(n) + count := int64(discarded) + b.count += count + b.limit -= count + return +} + +func (b *BufioReader) Read(p []byte) (int, error) { + if b.limit <= 0 { + return 0, io.EOF + } + if int64(len(p)) > b.limit { + p = p[0:b.limit] + } + n, err := b.reader.Read(p) + count := int64(n) + b.count += count + b.limit -= count + return n, err +} + +func (b *BufioReader) ReadByte() (byte, error) { + if b.limit <= 0 { + return 0, io.EOF + } + a, err := b.reader.ReadByte() + if err == nil { + b.count++ + b.limit-- + } + return a, err +} + +func (b *BufioReader) Reset(r io.Reader) { + b.reader.Reset(r) + b.count = 0 + b.limit = math.MaxInt64 +} + +func (b *BufioReader) WriteTo(w io.Writer) (int64, error) { + if b.limit <= 0 { + return 0, io.EOF + } + n, err := b.reader.WriteTo(w) + b.count += n + b.limit -= n + return n, err +} diff --git a/utils/bufio.go b/utils/bufio_writer.go similarity index 62% rename from utils/bufio.go rename to utils/bufio_writer.go index cb6e7956..ac8f5cfb 100644 --- a/utils/bufio.go +++ b/utils/bufio_writer.go @@ -92,69 +92,3 @@ func (b *BufioWriter) WriteString(s string) (int, error) { b.count += int64(n) return n, err } - -type BufioReader struct { - reader *bufio.Reader - count int64 -} - -func NewBufioReader(r io.Reader) *BufioReader { - return &BufioReader{ - reader: bufio.NewReader(r), - } -} - -func NewBufioReaderSize(r io.Reader, size int) *BufioReader { - return &BufioReader{ - reader: bufio.NewReaderSize(r, size), - } -} - -func (b *BufioReader) ResetCount() { - b.count = 0 -} - -func (b *BufioReader) Count() int64 { - return b.count -} - -func (b *BufioReader) Buffered() int { - return b.reader.Buffered() -} - -func (b *BufioReader) Discard(n int) (discarded int, err error) { - discarded, err = b.reader.Discard(n) - b.count += int64(discarded) - return -} - -func (b *BufioReader) Peek(n int) ([]byte, error) { - a, err := b.reader.Peek(n) - b.count += int64(len(a)) - return a, err -} - -func (b *BufioReader) Read(p []byte) (int, error) { - n, err := b.reader.Read(p) - b.count += int64(n) - return n, err -} - -func (b *BufioReader) ReadByte() (byte, error) { - a, err := b.reader.ReadByte() - if err == nil { - b.count++ - } - return a, err -} - -func (b *BufioReader) Reset(r io.Reader) { - b.reader.Reset(r) - b.count = 0 -} - -func (b *BufioReader) WriteTo(w io.Writer) (int64, error) { - n, err := b.reader.WriteTo(w) - b.count += n - return n, err -} diff --git a/utils/limit_reader.go b/utils/limit_reader.go deleted file mode 100644 index bb31a0f3..00000000 --- a/utils/limit_reader.go +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright 2015-2017 HenryLee. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package utils - -import ( - "io" -) - -// The actual read interface needed by NewReader. -// If the passed in io.Reader does not also have ReadByte, -// the NewReader will introduce its own buffering. -type Reader interface { - io.Reader - io.ByteReader -} - -// LimitReader returns a Reader that reads from r -// but stops with EOF after n bytes. -// The underlying implementation is a *LimitedReader. -func LimitReader(r Reader, n int64) *LimitedReader { return &LimitedReader{r, n} } - -// A LimitedBufioReader reads from bufio.Reader but limits the amount of -// data returned to just N bytes. Each call to Read -// updates N to reflect the new amount remaining. -// Read returns EOF when N <= 0 or when the underlying R returns EOF. -type LimitedReader struct { - R Reader // underlying reader - N int64 // max bytes remaining -} - -func (l *LimitedReader) ResetLimit(n int64) { - l.N = n -} - -func (l *LimitedReader) Read(p []byte) (n int, err error) { - if l.N <= 0 { - return 0, io.EOF - } - if int64(len(p)) > l.N { - p = p[0:l.N] - } - n, err = l.R.Read(p) - l.N -= int64(n) - return -} - -func (l *LimitedReader) ReadByte() (byte, error) { - if l.N <= 0 { - return 0, io.EOF - } - l.N-- - return l.R.ReadByte() -}