From 056e2ae72da987a02fa170287e38d4454cd60e87 Mon Sep 17 00:00:00 2001 From: fuxiao <576101059@qq.com> Date: Fri, 26 Aug 2022 14:13:35 +0800 Subject: [PATCH 1/2] support tcp network protocol --- network/tcp/client.go | 57 ++++++++ network/tcp/client_conn.go | 239 +++++++++++++++++++++++++++++++ network/tcp/client_options.go | 18 +++ network/tcp/client_test.go | 42 ++++++ network/tcp/def.go | 11 ++ network/tcp/go.mod | 7 + network/tcp/go.sum | 134 ++++++++++++++++++ network/tcp/pack.go | 103 ++++++++++++++ network/tcp/pack_test.go | 26 ++++ network/tcp/server.go | 141 +++++++++++++++++++ network/tcp/server_conn.go | 249 +++++++++++++++++++++++++++++++++ network/tcp/server_conn_mgr.go | 66 +++++++++ network/tcp/server_options.go | 24 ++++ network/tcp/server_test.go | 44 ++++++ 14 files changed, 1161 insertions(+) create mode 100644 network/tcp/client.go create mode 100644 network/tcp/client_conn.go create mode 100644 network/tcp/client_options.go create mode 100644 network/tcp/client_test.go create mode 100644 network/tcp/def.go create mode 100644 network/tcp/go.mod create mode 100644 network/tcp/go.sum create mode 100644 network/tcp/pack.go create mode 100644 network/tcp/pack_test.go create mode 100644 network/tcp/server.go create mode 100644 network/tcp/server_conn.go create mode 100644 network/tcp/server_conn_mgr.go create mode 100644 network/tcp/server_options.go create mode 100644 network/tcp/server_test.go diff --git a/network/tcp/client.go b/network/tcp/client.go new file mode 100644 index 00000000..b1b089a8 --- /dev/null +++ b/network/tcp/client.go @@ -0,0 +1,57 @@ +package tcp + +import ( + "github.com/dobyte/due/network" + "net" +) + +type client struct { + opts *clientOptions // 配置 + connectHandler network.ConnectHandler // 连接打开hook函数 + disconnectHandler network.DisconnectHandler // 连接关闭hook函数 + receiveHandler network.ReceiveHandler // 接收消息hook函数 +} + +var _ network.Client = &client{} + +func NewClient(opts ...ClientOption) network.Client { + o := &clientOptions{ + addr: "127.0.0.1:3553", + maxMsgLength: 1024 * 1024, + } + for _, opt := range opts { + opt(o) + } + + return &client{opts: o} +} + +// Dial 拨号连接 +func (c *client) Dial() (network.Conn, error) { + addr, err := net.ResolveTCPAddr("tcp", c.opts.addr) + if err != nil { + return nil, err + } + + conn, err := net.Dial(addr.Network(), addr.String()) + if err != nil { + return nil, err + } + + return newClientConn(c, conn), nil +} + +// OnConnect 监听连接打开 +func (c *client) OnConnect(handler network.ConnectHandler) { + c.connectHandler = handler +} + +// OnDisconnect 监听连接关闭 +func (c *client) OnDisconnect(handler network.DisconnectHandler) { + c.disconnectHandler = handler +} + +// OnReceive 监听接收到消息 +func (c *client) OnReceive(handler network.ReceiveHandler) { + c.receiveHandler = handler +} diff --git a/network/tcp/client_conn.go b/network/tcp/client_conn.go new file mode 100644 index 00000000..1fad75f3 --- /dev/null +++ b/network/tcp/client_conn.go @@ -0,0 +1,239 @@ +package tcp + +import ( + "github.com/dobyte/due/internal/xnet" + "github.com/dobyte/due/log" + "github.com/dobyte/due/network" + "net" + "sync" + "sync/atomic" +) + +type clientConn struct { + rw sync.RWMutex + id int64 // 连接ID + uid int64 // 用户ID + conn net.Conn // TCP源连接 + state int32 // 连接状态 + client *client // 客户端 + chWrite chan chWrite // 写入队列 + done chan struct{} // 写入完成信号 +} + +var _ network.Conn = &clientConn{} + +func newClientConn(client *client, conn net.Conn) network.Conn { + c := &clientConn{ + id: 1, + conn: conn, + state: int32(network.ConnOpened), + client: client, + chWrite: make(chan chWrite), + done: make(chan struct{}), + } + + if c.client.connectHandler != nil { + c.client.connectHandler(c) + } + + go c.read() + + go c.write() + + return c +} + +// ID 获取连接ID +func (c *clientConn) ID() int64 { + return c.id +} + +// UID 获取用户ID +func (c *clientConn) UID() int64 { + c.rw.RLock() + defer c.rw.RUnlock() + + return c.uid +} + +// Bind 绑定用户ID +func (c *clientConn) Bind(uid int64) { + c.rw.Lock() + defer c.rw.Unlock() + + c.uid = uid +} + +// Send 发送消息(同步) +func (c *clientConn) Send(msg []byte, msgType ...int) error { + c.rw.RLock() + defer c.rw.RUnlock() + + if err := c.checkState(); err != nil { + return err + } + + _, err := c.conn.Write(msg) + return err +} + +// Push 发送消息(异步) +func (c *clientConn) Push(msg []byte, msgType ...int) error { + c.rw.RLock() + defer c.rw.RUnlock() + + if err := c.checkState(); err != nil { + return err + } + + c.chWrite <- chWrite{typ: dataPacket, msg: msg} + + return nil +} + +// State 获取连接状态 +func (c *clientConn) State() network.ConnState { + return network.ConnState(atomic.LoadInt32(&c.state)) +} + +// Close 关闭连接 +func (c *clientConn) Close(isForce ...bool) error { + c.rw.Lock() + defer c.rw.Unlock() + + if err := c.checkState(); err != nil { + return err + } + + if len(isForce) > 0 && isForce[0] { + atomic.StoreInt32(&c.state, int32(network.ConnClosed)) + } else { + atomic.StoreInt32(&c.state, int32(network.ConnHanged)) + c.chWrite <- chWrite{typ: closeSig} + <-c.done + } + + close(c.chWrite) + + return c.conn.Close() +} + +// LocalIP 获取本地IP +func (c *clientConn) LocalIP() (string, error) { + addr, err := c.LocalAddr() + if err != nil { + return "", err + } + + return xnet.ExtractIP(addr) +} + +// LocalAddr 获取本地地址 +func (c *clientConn) LocalAddr() (net.Addr, error) { + c.rw.RLock() + defer c.rw.RUnlock() + + if err := c.checkState(); err != nil { + return nil, err + } + + return c.conn.LocalAddr(), nil +} + +// RemoteIP 获取远端IP +func (c *clientConn) RemoteIP() (string, error) { + addr, err := c.RemoteAddr() + if err != nil { + return "", err + } + + return xnet.ExtractIP(addr) +} + +// RemoteAddr 获取远端地址 +func (c *clientConn) RemoteAddr() (net.Addr, error) { + c.rw.RLock() + defer c.rw.RUnlock() + + if err := c.checkState(); err != nil { + return nil, err + } + + return c.conn.RemoteAddr(), nil +} + +// 关闭连接 +func (c *clientConn) close() { + atomic.StoreInt32(&c.state, int32(network.ConnClosed)) + + if c.client.disconnectHandler != nil { + c.client.disconnectHandler(c) + } +} + +// 检测连接状态 +func (c *clientConn) checkState() error { + switch network.ConnState(atomic.LoadInt32(&c.state)) { + case network.ConnHanged: + return network.ErrConnectionHanged + case network.ConnClosed: + return network.ErrConnectionClosed + } + + return nil +} + +// 读取消息 +func (c *clientConn) read() { + defer c.close() + + for { + msg, err := readMsgFromConn(c.conn, c.client.opts.maxMsgLength) + if err != nil { + if err == errMsgSizeTooLarge { + log.Warnf("the msg size too large, has been ignored") + continue + } + return + } + + switch c.State() { + case network.ConnHanged: + continue + case network.ConnClosed: + return + } + + if c.client.receiveHandler != nil { + c.client.receiveHandler(c, msg, 0) + } + } +} + +// 写入消息 +func (c *clientConn) write() { + for { + select { + case write, ok := <-c.chWrite: + if !ok { + return + } + + if write.typ == closeSig { + c.done <- struct{}{} + return + } + + buf, err := pack(write.msg) + if err != nil { + log.Errorf("packet message error: %v", err) + continue + } + + if _, err = c.conn.Write(buf); err != nil { + log.Errorf("write message error: %v", err) + continue + } + } + } +} diff --git a/network/tcp/client_options.go b/network/tcp/client_options.go new file mode 100644 index 00000000..c6667b20 --- /dev/null +++ b/network/tcp/client_options.go @@ -0,0 +1,18 @@ +package tcp + +type ClientOption func(o *clientOptions) + +type clientOptions struct { + addr string // 地址 + maxMsgLength int // 最大消息长度 +} + +// WithClientDialAddr 设置拨号地址 +func WithClientDialAddr(addr string) ClientOption { + return func(o *clientOptions) { o.addr = addr } +} + +// WithClientMaxMsgLength 设置消息最大长度 +func WithClientMaxMsgLength(maxMsgLength int) ClientOption { + return func(o *clientOptions) { o.maxMsgLength = maxMsgLength } +} diff --git a/network/tcp/client_test.go b/network/tcp/client_test.go new file mode 100644 index 00000000..6fdcb2f2 --- /dev/null +++ b/network/tcp/client_test.go @@ -0,0 +1,42 @@ +package tcp_test + +import ( + "github.com/dobyte/due/network" + "github.com/dobyte/due/network/tcp" + "testing" + "time" +) + +func TestNewClient(t *testing.T) { + client := tcp.NewClient( + tcp.WithClientDialAddr("127.0.0.1:3553"), + ) + + client.OnConnect(func(conn network.Conn) { + t.Log("connection is opened") + }) + client.OnDisconnect(func(conn network.Conn) { + t.Log("connection is closed") + }) + client.OnReceive(func(conn network.Conn, msg []byte, msgType int) { + t.Logf("receive msg from server, msg: %s", string(msg)) + }) + + conn, err := client.Dial() + if err != nil { + t.Fatal(err) + } + + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + defer conn.Close() + for { + select { + case <-ticker.C: + if err = conn.Push([]byte("hello server~~")); err != nil { + t.Error(err) + return + } + } + } +} diff --git a/network/tcp/def.go b/network/tcp/def.go new file mode 100644 index 00000000..9b337f2e --- /dev/null +++ b/network/tcp/def.go @@ -0,0 +1,11 @@ +package tcp + +const ( + closeSig int = iota // 关闭信号 + dataPacket // 数据包 +) + +type chWrite struct { + typ int + msg []byte +} diff --git a/network/tcp/go.mod b/network/tcp/go.mod new file mode 100644 index 00000000..57af4774 --- /dev/null +++ b/network/tcp/go.mod @@ -0,0 +1,7 @@ +module github.com/dobyte/due/network/tcp + +go 1.16 + +require github.com/dobyte/due v0.0.1 + +replace github.com/dobyte/due => ../../ diff --git a/network/tcp/go.sum b/network/tcp/go.sum new file mode 100644 index 00000000..5698c0d0 --- /dev/null +++ b/network/tcp/go.sum @@ -0,0 +1,134 @@ +cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= +github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= +github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= +github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= +github.com/cncf/udpa/go v0.0.0-20210930031921-04548b0d99d4/go.mod h1:6pvJx4me5XPnfI9Z40ddWsdw2W/uZgQLFXToKeRcDiI= +github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= +github.com/cncf/xds/go v0.0.0-20211001041855-01bcc9b48dfe/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= +github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= +github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= +github.com/envoyproxy/go-control-plane v0.10.2-0.20220325020618-49ff273808a1/go.mod h1:KJwIaB5Mv44NWtYuAOFCVOjcI94vtpEz2JU/D2v6IjE= +github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= +github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= +github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= +github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= +github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= +github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= +github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= +github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= +github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE= +github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= +golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= +golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= +golang.org/x/net v0.0.0-20220524220425-1d687d428aca/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= +golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= +google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= +google.golang.org/genproto v0.0.0-20200513103714-09dca8ec2884/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= +google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= +google.golang.org/genproto v0.0.0-20220525015930-6ca3db687a9d/go.mod h1:yKyY4AMRwFiC8yMMNaMi+RkCnjZJt9LoWuvhXjMs+To= +google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= +google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= +google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= +google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= +google.golang.org/grpc v1.33.1/go.mod h1:fr5YgcSWrqhRRxogOsw7RzIpsmvOZ6IcH4kBYTpR3n0= +google.golang.org/grpc v1.36.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= +google.golang.org/grpc v1.46.2/go.mod h1:vN9eftEi1UMyUsIF80+uQXhHjbXYbm0uXoFCACuMGWk= +google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= +google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= +google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= +google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= +google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= +google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/network/tcp/pack.go b/network/tcp/pack.go new file mode 100644 index 00000000..b68b646f --- /dev/null +++ b/network/tcp/pack.go @@ -0,0 +1,103 @@ +/** + * @Author: fuxiao + * @Email: 576101059@qq.com + * @Date: 2022/5/12 10:58 下午 + * @Desc: TODO + */ + +package tcp + +import ( + "bytes" + "encoding/binary" + "errors" + "io" + "net" +) + +const ( + msgLenBytes uint32 = 4 // 消息长度字节数 + msgByteOrder string = "little" // 消息字节排序 +) + +var ( + errMsgSizeTooLarge = errors.New("the msg size too large") +) + +// 打包消息 +func pack(msg []byte) (packet []byte, err error) { + var buf bytes.Buffer + buf.Grow(len(msg) + int(msgLenBytes)) + + if err = binary.Write(&buf, byteOrder(), uint32(len(msg))); err != nil { + return + } + + if err = binary.Write(&buf, byteOrder(), msg); err != nil { + return + } + + packet = buf.Bytes() + + return +} + +// 解包消息 +func unpack(packet []byte) (msg []byte, err error) { + var ( + buf = bytes.NewBuffer(packet) + msgLen uint32 + ) + + if err = binary.Read(buf, byteOrder(), &msgLen); err != nil { + return + } + + if msgLen > 0 { + msg = make([]byte, msgLen) + if err = binary.Read(buf, byteOrder(), &msg); err != nil { + return + } + } + + return +} + +// 读取连接数据 +func readMsgFromConn(conn net.Conn, maxMsgLength int) (msg []byte, err error) { + packet := make([]byte, msgLenBytes) + if _, err = io.ReadFull(conn, packet); err != nil { + return + } + + var ( + buf = bytes.NewBuffer(packet) + msgLen uint32 + ) + + if err = binary.Read(buf, byteOrder(), &msgLen); err != nil { + return + } + + if msgLen > 0 { + msg = make([]byte, msgLen) + if _, err = io.ReadFull(conn, msg); err != nil { + return + } + + if int(msgLen) > maxMsgLength { + err = errMsgSizeTooLarge + } + } + + return +} + +func byteOrder() binary.ByteOrder { + switch msgByteOrder { + case "little": + return binary.LittleEndian + default: + return binary.BigEndian + } +} diff --git a/network/tcp/pack_test.go b/network/tcp/pack_test.go new file mode 100644 index 00000000..935b0617 --- /dev/null +++ b/network/tcp/pack_test.go @@ -0,0 +1,26 @@ +/** + * @Author: fuxiao + * @Email: 576101059@qq.com + * @Date: 2022/5/13 9:06 上午 + * @Desc: TODO + */ + +package tcp + +import "testing" + +func Test_Pack(t *testing.T) { + msg := []byte("hello world") + + packet, err := pack(msg) + if err != nil { + t.Fatal(err) + } + + msg, err = unpack(packet) + if err != nil { + t.Fatal(err) + } + + t.Log(string(msg)) +} diff --git a/network/tcp/server.go b/network/tcp/server.go new file mode 100644 index 00000000..fe04f10d --- /dev/null +++ b/network/tcp/server.go @@ -0,0 +1,141 @@ +/** + * @Author: fuxiao + * @Email: 576101059@qq.com + * @Date: 2022/5/11 10:02 上午 + * @Desc: TODO + */ + +package tcp + +import ( + "net" + "time" + + "github.com/dobyte/due/network" +) + +type server struct { + opts *serverOptions // 配置 + listener net.Listener // 监听器 + connMgr *serverConnMgr // 连接管理器 + startHandler network.StartHandler // 服务器启动hook函数 + stopHandler network.CloseHandler // 服务器关闭hook函数 + connectHandler network.ConnectHandler // 连接打开hook函数 + disconnectHandler network.DisconnectHandler // 连接关闭hook函数 + receiveHandler network.ReceiveHandler // 接收消息hook函数 +} + +var _ network.Server = &server{} + +func NewServer(opts ...ServerOption) network.Server { + o := &serverOptions{ + addr: ":3553", + maxConnNum: 5000, + maxMsgLength: 1024 * 1024, + } + for _, opt := range opts { + opt(o) + } + + s := &server{} + s.opts = o + s.connMgr = newConnMgr(s) + + return s +} + +// Addr 监听地址 +func (s *server) Addr() string { + return s.opts.addr +} + +// Start 启动服务器 +func (s *server) Start() error { + if err := s.init(); err != nil { + return err + } + + if s.startHandler != nil { + s.startHandler() + } + + go s.serve() + + return nil +} + +// Stop 关闭服务器 +func (s *server) Stop() error { + if err := s.listener.Close(); err != nil { + return err + } + + s.connMgr.close() + + return nil +} + +// Protocol 协议 +func (s *server) Protocol() string { + return "tcp" +} + +// OnStart 监听服务器启动 +func (s *server) OnStart(handler network.StartHandler) { + s.startHandler = handler +} + +// OnStop 监听服务器关闭 +func (s *server) OnStop(handler network.CloseHandler) { + s.stopHandler = handler +} + +// OnConnect 监听连接打开 +func (s *server) OnConnect(handler network.ConnectHandler) { + s.connectHandler = handler +} + +// OnDisconnect 监听连接关闭 +func (s *server) OnDisconnect(handler network.DisconnectHandler) { + s.disconnectHandler = handler +} + +// OnReceive 监听接收到消息 +func (s *server) OnReceive(handler network.ReceiveHandler) { + s.receiveHandler = handler +} + +// 初始化TCP服务器 +func (s *server) init() error { + addr, err := net.ResolveTCPAddr("tcp", s.opts.addr) + if err != nil { + return err + } + + ln, err := net.ListenTCP(addr.Network(), addr) + if err != nil { + return err + } + + s.listener = ln + + return nil +} + +// 等待连接 +func (s *server) serve() { + for { + conn, err := s.listener.Accept() + if err != nil { + if e, ok := err.(net.Error); ok && e.Timeout() { + time.Sleep(time.Millisecond) + continue + } + return + } + + if err := s.connMgr.allocate(conn); err != nil { + _ = conn.Close() + } + } +} diff --git a/network/tcp/server_conn.go b/network/tcp/server_conn.go new file mode 100644 index 00000000..86e88776 --- /dev/null +++ b/network/tcp/server_conn.go @@ -0,0 +1,249 @@ +/** + * @Author: fuxiao + * @Email: 576101059@qq.com + * @Date: 2022/5/11 10:49 上午 + * @Desc: TODO + */ + +package tcp + +import ( + "github.com/dobyte/due/internal/xnet" + "net" + "sync" + "sync/atomic" + + "github.com/dobyte/due/log" + "github.com/dobyte/due/network" +) + +type serverConn struct { + rw sync.RWMutex // 锁 + id int64 // 连接ID + uid int64 // 用户ID + state int32 // 连接状态 + conn net.Conn // TCP源连接 + connMgr *serverConnMgr // 连接管理 + chWrite chan chWrite // 写入队列 + done chan struct{} // 写入完成信号 +} + +var _ network.Conn = &serverConn{} + +// ID 获取连接ID +func (c *serverConn) ID() int64 { + return c.id +} + +// UID 获取用户ID +func (c *serverConn) UID() int64 { + c.rw.RLock() + defer c.rw.RUnlock() + + return c.uid +} + +// Bind 绑定用户ID +func (c *serverConn) Bind(uid int64) { + c.rw.Lock() + defer c.rw.Unlock() + + c.uid = uid +} + +// Send 发送消息(同步) +func (c *serverConn) Send(msg []byte, msgType ...int) error { + c.rw.RLock() + defer c.rw.RUnlock() + + if err := c.checkState(); err != nil { + return err + } + + _, err := c.conn.Write(msg) + + return err +} + +// Push 发送消息(异步) +func (c *serverConn) Push(msg []byte, msgType ...int) error { + c.rw.RLock() + defer c.rw.RUnlock() + + if err := c.checkState(); err != nil { + return err + } + + c.chWrite <- chWrite{typ: dataPacket, msg: msg} + + return nil +} + +// State 获取连接状态 +func (c *serverConn) State() network.ConnState { + return network.ConnState(atomic.LoadInt32(&c.state)) +} + +// Close 关闭连接 +func (c *serverConn) Close(isForce ...bool) error { + c.rw.Lock() + defer c.rw.Unlock() + + if err := c.checkState(); err != nil { + return err + } + + if len(isForce) > 0 && isForce[0] { + atomic.StoreInt32(&c.state, int32(network.ConnClosed)) + } else { + atomic.StoreInt32(&c.state, int32(network.ConnHanged)) + c.chWrite <- chWrite{typ: closeSig} + <-c.done + } + + close(c.chWrite) + + err := c.conn.Close() + c.conn = nil + c.connMgr.recycle(c) + + return err +} + +// 关闭连接 +func (c *serverConn) close() { + atomic.StoreInt32(&c.state, int32(network.ConnClosed)) + + if c.connMgr.server.disconnectHandler != nil { + c.connMgr.server.disconnectHandler(c) + } +} + +// LocalIP 获取本地IP +func (c *serverConn) LocalIP() (string, error) { + addr, err := c.LocalAddr() + if err != nil { + return "", err + } + + return xnet.ExtractIP(addr) +} + +// LocalAddr 获取本地地址 +func (c *serverConn) LocalAddr() (net.Addr, error) { + c.rw.RLock() + defer c.rw.RUnlock() + + if err := c.checkState(); err != nil { + return nil, err + } + + return c.conn.LocalAddr(), nil +} + +// RemoteIP 获取远端IP +func (c *serverConn) RemoteIP() (string, error) { + addr, err := c.RemoteAddr() + if err != nil { + return "", err + } + + return xnet.ExtractIP(addr) +} + +// RemoteAddr 获取远端地址 +func (c *serverConn) RemoteAddr() (net.Addr, error) { + c.rw.RLock() + defer c.rw.RUnlock() + + if err := c.checkState(); err != nil { + return nil, err + } + + return c.conn.RemoteAddr(), nil +} + +// 检测连接状态 +func (c *serverConn) checkState() error { + switch network.ConnState(atomic.LoadInt32(&c.state)) { + case network.ConnHanged: + return network.ErrConnectionHanged + case network.ConnClosed: + return network.ErrConnectionClosed + } + + return nil +} + +// 初始化连接 +func (c *serverConn) init(conn net.Conn, cm *serverConnMgr) { + c.id = cm.id + c.conn = conn + c.connMgr = cm + c.chWrite = make(chan chWrite, 256) + c.done = make(chan struct{}) + atomic.StoreInt32(&c.state, int32(network.ConnOpened)) + + if c.connMgr.server.connectHandler != nil { + c.connMgr.server.connectHandler(c) + } + + go c.read() + + go c.write() +} + +// 读取消息 +func (c *serverConn) read() { + defer c.close() + + for { + msg, err := readMsgFromConn(c.conn, c.connMgr.server.opts.maxMsgLength) + if err != nil { + if err == errMsgSizeTooLarge { + log.Warnf("the msg size too large, has been ignored") + continue + } + break + } + + switch c.State() { + case network.ConnHanged: + continue + case network.ConnClosed: + return + } + + if c.connMgr.server.receiveHandler != nil { + c.connMgr.server.receiveHandler(c, msg, 0) + } + } +} + +// 写入消息 +func (c *serverConn) write() { + for { + select { + case write, ok := <-c.chWrite: + if !ok { + return + } + + if write.typ == closeSig { + c.done <- struct{}{} + return + } + + buf, err := pack(write.msg) + if err != nil { + log.Errorf("packet message error: %v", err) + continue + } + + if _, err = c.conn.Write(buf); err != nil { + log.Errorf("write message error: %v", err) + continue + } + } + } +} diff --git a/network/tcp/server_conn_mgr.go b/network/tcp/server_conn_mgr.go new file mode 100644 index 00000000..ac4ac7d6 --- /dev/null +++ b/network/tcp/server_conn_mgr.go @@ -0,0 +1,66 @@ +/** + * @Author: fuxiao + * @Email: 576101059@qq.com + * @Date: 2022/5/15 9:55 下午 + * @Desc: TODO + */ + +package tcp + +import ( + "github.com/dobyte/due/network" + "net" + "sync" +) + +type serverConnMgr struct { + mu sync.Mutex // 连接锁 + id int64 // 连接ID + pool sync.Pool // 连接池 + conns map[net.Conn]*serverConn // 连接集合 + server *server // 服务器 +} + +func newConnMgr(server *server) *serverConnMgr { + return &serverConnMgr{ + server: server, + conns: make(map[net.Conn]*serverConn), + pool: sync.Pool{New: func() interface{} { return &serverConn{} }}, + } +} + +// 关闭连接 +func (cm *serverConnMgr) close() { + cm.mu.Lock() + defer cm.mu.Unlock() + + for _, conn := range cm.conns { + _ = conn.Close(false) + } +} + +// 分配连接 +func (cm *serverConnMgr) allocate(c net.Conn) error { + cm.mu.Lock() + defer cm.mu.Unlock() + + if len(cm.conns) >= cm.server.opts.maxConnNum { + return network.ErrTooManyConnection + } + + cm.id++ + conn := cm.pool.Get().(*serverConn) + conn.init(c, cm) + cm.conns[c] = conn + + return nil +} + +// 回收连接 +func (cm *serverConnMgr) recycle(conn *serverConn) { + cm.mu.Lock() + defer cm.mu.Unlock() + + delete(cm.conns, conn.conn) + cm.pool.Put(conn) +} diff --git a/network/tcp/server_options.go b/network/tcp/server_options.go new file mode 100644 index 00000000..156229dd --- /dev/null +++ b/network/tcp/server_options.go @@ -0,0 +1,24 @@ +package tcp + +type ServerOption func(o *serverOptions) + +type serverOptions struct { + addr string // 监听地址 + maxConnNum int // 最大连接数 + maxMsgLength int // 最大消息长度 +} + +// WithServerListenAddr 设置监听地址 +func WithServerListenAddr(addr string) ServerOption { + return func(o *serverOptions) { o.addr = addr } +} + +// WithServerMaxConnNum 设置连接的最大连接数 +func WithServerMaxConnNum(maxConnNum int) ServerOption { + return func(o *serverOptions) { o.maxConnNum = maxConnNum } +} + +// WithServerMaxMsgLength 设置消息最大长度 +func WithServerMaxMsgLength(maxMsgLength int) ServerOption { + return func(o *serverOptions) { o.maxMsgLength = maxMsgLength } +} diff --git a/network/tcp/server_test.go b/network/tcp/server_test.go new file mode 100644 index 00000000..ef81de9b --- /dev/null +++ b/network/tcp/server_test.go @@ -0,0 +1,44 @@ +/** + * @Author: fuxiao + * @Email: 576101059@qq.com + * @Date: 2022/5/11 11:31 上午 + * @Desc: TODO + */ + +package tcp_test + +import ( + "github.com/dobyte/due/network/tcp" + "testing" + + "github.com/dobyte/due/network" +) + +func TestServer(t *testing.T) { + server := tcp.NewServer( + tcp.WithServerListenAddr(":3553"), + tcp.WithServerMaxConnNum(5), + ) + server.OnStart(func() { + t.Logf("server is started") + }) + server.OnConnect(func(conn network.Conn) { + t.Logf("connection is opened, connection id: %d", conn.ID()) + }) + server.OnDisconnect(func(conn network.Conn) { + t.Logf("connection is closed, connection id: %d", conn.ID()) + }) + server.OnReceive(func(conn network.Conn, msg []byte, msgType int) { + t.Logf("receive msg from client, connection id: %d, msg: %s", conn.ID(), string(msg)) + + if err := conn.Push([]byte("I'm fine~~")); err != nil { + t.Error(err) + } + }) + + if err := server.Start(); err != nil { + t.Fatal(err) + } + + select {} +} From 47a661453c41ebc33c6c64876dfc896ddad6961a Mon Sep 17 00:00:00 2001 From: fuxiao <576101059@qq.com> Date: Fri, 26 Aug 2022 14:48:42 +0800 Subject: [PATCH 2/2] support ws network protocol --- network/ws/client.go | 60 ++++++++ network/ws/client_conn.go | 246 ++++++++++++++++++++++++++++++++ network/ws/client_options.go | 26 ++++ network/ws/def.go | 19 +++ network/ws/go.mod | 10 ++ network/ws/server.go | 166 ++++++++++++++++++++++ network/ws/server_conn.go | 256 ++++++++++++++++++++++++++++++++++ network/ws/server_conn_mgr.go | 67 +++++++++ network/ws/server_options.go | 43 ++++++ network/ws/server_test.go | 46 ++++++ 10 files changed, 939 insertions(+) create mode 100644 network/ws/client.go create mode 100644 network/ws/client_conn.go create mode 100644 network/ws/client_options.go create mode 100644 network/ws/def.go create mode 100644 network/ws/go.mod create mode 100644 network/ws/server.go create mode 100644 network/ws/server_conn.go create mode 100644 network/ws/server_conn_mgr.go create mode 100644 network/ws/server_options.go create mode 100644 network/ws/server_test.go diff --git a/network/ws/client.go b/network/ws/client.go new file mode 100644 index 00000000..88b91963 --- /dev/null +++ b/network/ws/client.go @@ -0,0 +1,60 @@ +package ws + +import ( + "github.com/dobyte/due/network" + "github.com/gorilla/websocket" + "time" +) + +type client struct { + opts *clientOptions // 配置 + dialer *websocket.Dialer // 拨号器 + connectHandler network.ConnectHandler // 连接打开hook函数 + disconnectHandler network.DisconnectHandler // 连接关闭hook函数 + receiveHandler network.ReceiveHandler // 接收消息hook函数 +} + +var _ network.Client = &client{} + +func NewClient(opts ...ClientOption) network.Client { + o := &clientOptions{ + url: "ws://127.0.0.1:3553", + maxMsgLength: 1024 * 1024, + handshakeTimeout: 10 * time.Second, + } + for _, opt := range opts { + opt(o) + } + + return &client{ + opts: o, + dialer: &websocket.Dialer{ + HandshakeTimeout: o.handshakeTimeout, + }, + } +} + +// Dial 拨号连接 +func (c *client) Dial() (network.Conn, error) { + conn, _, err := c.dialer.Dial(c.opts.url, nil) + if err != nil { + return nil, err + } + + return newClientConn(c, conn), nil +} + +// OnConnect 监听连接打开 +func (c *client) OnConnect(handler network.ConnectHandler) { + c.connectHandler = handler +} + +// OnDisconnect 监听连接关闭 +func (c *client) OnDisconnect(handler network.DisconnectHandler) { + c.disconnectHandler = handler +} + +// OnReceive 监听接收到消息 +func (c *client) OnReceive(handler network.ReceiveHandler) { + c.receiveHandler = handler +} diff --git a/network/ws/client_conn.go b/network/ws/client_conn.go new file mode 100644 index 00000000..d39425bf --- /dev/null +++ b/network/ws/client_conn.go @@ -0,0 +1,246 @@ +package ws + +import ( + "github.com/dobyte/due/internal/xnet" + "github.com/dobyte/due/log" + "github.com/dobyte/due/network" + "github.com/gorilla/websocket" + "net" + "sync" + "sync/atomic" +) + +type clientConn struct { + rw sync.RWMutex // 锁 + id int64 // 连接ID + uid int64 // 用户ID + conn *websocket.Conn // TCP源连接 + state int32 // 连接状态 + client *client // 客户端 + chWrite chan chWrite // 写入队列 + done chan struct{} // 写入完成信号 +} + +var _ network.Conn = &clientConn{} + +func newClientConn(client *client, conn *websocket.Conn) network.Conn { + c := &clientConn{ + id: 1, + conn: conn, + state: int32(network.ConnOpened), + client: client, + chWrite: make(chan chWrite), + done: make(chan struct{}), + } + + if c.client.connectHandler != nil { + c.client.connectHandler(c) + } + + go c.read() + + go c.write() + + return c +} + +// ID 获取连接ID +func (c *clientConn) ID() int64 { + return c.id +} + +// UID 获取用户ID +func (c *clientConn) UID() int64 { + c.rw.RLock() + defer c.rw.RUnlock() + + return c.uid +} + +// Bind 绑定用户ID +func (c *clientConn) Bind(uid int64) { + c.rw.Lock() + defer c.rw.Unlock() + + c.uid = uid +} + +// Send 发送消息(同步) +func (c *clientConn) Send(msg []byte, msgType ...int) error { + c.rw.RLock() + defer c.rw.RUnlock() + + if err := c.checkState(); err != nil { + return err + } + + if len(msgType) == 0 { + msgType = append(msgType, TextMessage) + } + + switch msgType[0] { + case TextMessage, BinaryMessage: + return c.conn.WriteMessage(msgType[0], msg) + default: + return network.ErrIllegalMsgType + } +} + +// Push 发送消息(异步) +func (c *clientConn) Push(msg []byte, msgType ...int) error { + c.rw.RLock() + defer c.rw.RUnlock() + + if err := c.checkState(); err != nil { + return err + } + + if len(msgType) == 0 { + msgType = append(msgType, TextMessage) + } + + switch msgType[0] { + case TextMessage, BinaryMessage: + c.chWrite <- chWrite{typ: dataPacket, msg: msg, msgType: msgType[0]} + default: + return network.ErrIllegalMsgType + } + + return nil +} + +// State 获取连接状态 +func (c *clientConn) State() network.ConnState { + return network.ConnState(atomic.LoadInt32(&c.state)) +} + +// Close 关闭连接 +func (c *clientConn) Close(isForce ...bool) error { + c.rw.Lock() + defer c.rw.Unlock() + + if err := c.checkState(); err != nil { + return err + } + + if len(isForce) > 0 && isForce[0] { + atomic.StoreInt32(&c.state, int32(network.ConnClosed)) + } else { + atomic.StoreInt32(&c.state, int32(network.ConnHanged)) + c.chWrite <- chWrite{typ: closeSig} + <-c.done + } + + close(c.chWrite) + + return c.conn.Close() +} + +// LocalIP 获取本地IP +func (c *clientConn) LocalIP() (string, error) { + addr, err := c.LocalAddr() + if err != nil { + return "", err + } + + return xnet.ExtractIP(addr) +} + +// LocalAddr 获取本地地址 +func (c *clientConn) LocalAddr() (net.Addr, error) { + c.rw.RLock() + defer c.rw.RUnlock() + + if err := c.checkState(); err != nil { + return nil, err + } + + return c.conn.LocalAddr(), nil +} + +// RemoteIP 获取远端IP +func (c *clientConn) RemoteIP() (string, error) { + addr, err := c.RemoteAddr() + if err != nil { + return "", err + } + + return xnet.ExtractIP(addr) +} + +// RemoteAddr 获取远端地址 +func (c *clientConn) RemoteAddr() (net.Addr, error) { + c.rw.RLock() + defer c.rw.RUnlock() + + if err := c.checkState(); err != nil { + return nil, err + } + + return c.conn.RemoteAddr(), nil +} + +// 关闭连接 +func (c *clientConn) close() { + atomic.StoreInt32(&c.state, int32(network.ConnClosed)) + + if c.client.disconnectHandler != nil { + c.client.disconnectHandler(c) + } +} + +// 检测连接状态 +func (c *clientConn) checkState() error { + switch network.ConnState(atomic.LoadInt32(&c.state)) { + case network.ConnHanged: + return network.ErrConnectionHanged + case network.ConnClosed: + return network.ErrConnectionClosed + } + + return nil +} + +// 读取消息 +func (c *clientConn) read() { + defer c.close() + + for { + msgType, buf, err := c.conn.ReadMessage() + if err != nil { + break + } + + switch c.State() { + case network.ConnHanged: + continue + case network.ConnClosed: + return + } + + if c.client.receiveHandler != nil { + c.client.receiveHandler(c, buf, msgType) + } + } +} + +// 写入消息 +func (c *clientConn) write() { + for { + select { + case write, ok := <-c.chWrite: + if !ok { + return + } + + if write.typ == closeSig { + c.done <- struct{}{} + return + } + + if err := c.conn.WriteMessage(write.msgType, write.msg); err != nil { + log.Errorf("write message error: %v", err) + } + } + } +} diff --git a/network/ws/client_options.go b/network/ws/client_options.go new file mode 100644 index 00000000..880d75ef --- /dev/null +++ b/network/ws/client_options.go @@ -0,0 +1,26 @@ +package ws + +import "time" + +type ClientOption func(o *clientOptions) + +type clientOptions struct { + url string // 地址 + maxMsgLength int // 最大消息长度 + handshakeTimeout time.Duration // 握手超时时间 +} + +// WithClientDialUrl 设置拨号链接 +func WithClientDialUrl(url string) ClientOption { + return func(o *clientOptions) { o.url = url } +} + +// WithClientMaxMsgLength 设置消息最大长度 +func WithClientMaxMsgLength(maxMsgLength int) ClientOption { + return func(o *clientOptions) { o.maxMsgLength = maxMsgLength } +} + +// WithClientHandshakeTimeout 设置握手超时时间 +func WithClientHandshakeTimeout(handshakeTimeout time.Duration) ClientOption { + return func(o *clientOptions) { o.handshakeTimeout = handshakeTimeout } +} diff --git a/network/ws/def.go b/network/ws/def.go new file mode 100644 index 00000000..387e5e6b --- /dev/null +++ b/network/ws/def.go @@ -0,0 +1,19 @@ +package ws + +import "github.com/gorilla/websocket" + +const ( + closeSig int = iota // 关闭信号 + dataPacket // 数据包 +) + +const ( + TextMessage = websocket.TextMessage + BinaryMessage = websocket.BinaryMessage +) + +type chWrite struct { + typ int + msg []byte + msgType int +} diff --git a/network/ws/go.mod b/network/ws/go.mod new file mode 100644 index 00000000..2accccf5 --- /dev/null +++ b/network/ws/go.mod @@ -0,0 +1,10 @@ +module github.com/dobyte/due/network/ws + +go 1.16 + +require ( + github.com/dobyte/due v0.0.1 + github.com/gorilla/websocket v1.5.0 +) + +replace github.com/dobyte/due => ../../ \ No newline at end of file diff --git a/network/ws/server.go b/network/ws/server.go new file mode 100644 index 00000000..42ffdc32 --- /dev/null +++ b/network/ws/server.go @@ -0,0 +1,166 @@ +/** + * @Author: fuxiao + * @Email: 576101059@qq.com + * @Date: 2022/3/29 7:45 下午 + * @Desc: Websocket服务器 + */ + +package ws + +import ( + "github.com/dobyte/due/log" + "net" + "net/http" + + "github.com/gorilla/websocket" + + "github.com/dobyte/due/network" +) + +type server struct { + opts *serverOptions // 配置 + listener net.Listener // 监听器 + connMgr *connMgr // 连接管理器 + startHandler network.StartHandler // 服务器启动hook函数 + stopHandler network.CloseHandler // 服务器关闭hook函数 + connectHandler network.ConnectHandler // 连接打开hook函数 + disconnectHandler network.DisconnectHandler // 连接关闭hook函数 + receiveHandler network.ReceiveHandler // 接收消息hook函数 +} + +var _ network.Server = &server{} + +func NewServer(opts ...ServerOption) network.Server { + o := &serverOptions{ + addr: ":3553", + maxConnNum: 5000, + path: "/", + checkOrigin: func(r *http.Request) bool { return true }, + } + for _, opt := range opts { + opt(o) + } + + s := &server{} + s.opts = o + s.connMgr = newConnMgr(s) + + return s +} + +// Addr 监听地址 +func (s *server) Addr() string { + return s.opts.addr +} + +// Protocol 协议 +func (s *server) Protocol() string { + return "websocket" +} + +// Start 启动服务器 +func (s *server) Start() error { + if err := s.init(); err != nil { + return err + } + + if s.startHandler != nil { + s.startHandler() + } + + go s.serve() + + return nil +} + +// Stop 关闭服务器 +func (s *server) Stop() error { + if err := s.listener.Close(); err != nil { + return err + } + + s.connMgr.close() + + return nil +} + +// 初始化服务器 +func (s *server) init() error { + addr, err := net.ResolveTCPAddr("tcp", s.opts.addr) + if err != nil { + return err + } + + ln, err := net.ListenTCP(addr.Network(), addr) + if err != nil { + return err + } + + s.listener = ln + + return nil +} + +// 启动服务器 +func (s *server) serve() { + upgrader := websocket.Upgrader{ + ReadBufferSize: 4096, + WriteBufferSize: 4096, + EnableCompression: true, + CheckOrigin: s.opts.checkOrigin, + } + + http.HandleFunc(s.opts.path, func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", 405) + return + } + + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Errorf("websocket upgrade error: %v", err) + return + } + + if err := s.connMgr.allocate(conn); err != nil { + _ = conn.Close() + } + }) + + var err error + + if s.opts.certFile != "" && s.opts.keyFile != "" { + err = http.ServeTLS(s.listener, nil, s.opts.certFile, s.opts.keyFile) + } else { + err = http.Serve(s.listener, nil) + } + + if err != nil { + log.Fatalf("websocket server start error: %v\n", err) + } +} + +// OnStart 监听服务器启动 +func (s *server) OnStart(handler network.StartHandler) { + s.startHandler = handler +} + +// OnStop 监听服务器关闭 +func (s *server) OnStop(handler network.CloseHandler) { + s.stopHandler = handler +} + +// OnConnect 监听连接打开 +func (s *server) OnConnect(handler network.ConnectHandler) { + s.connectHandler = handler +} + +// OnDisconnect 监听连接关闭 +func (s *server) OnDisconnect(handler network.DisconnectHandler) { + s.disconnectHandler = handler +} + +// OnReceive 监听接收到消息 +func (s *server) OnReceive(handler network.ReceiveHandler) { + s.receiveHandler = handler +} diff --git a/network/ws/server_conn.go b/network/ws/server_conn.go new file mode 100644 index 00000000..e54a3b00 --- /dev/null +++ b/network/ws/server_conn.go @@ -0,0 +1,256 @@ +/** + * @Author: fuxiao + * @Email: 576101059@qq.com + * @Date: 2022/5/27 5:03 下午 + * @Desc: TODO + */ + +package ws + +import ( + "github.com/dobyte/due/internal/xnet" + "github.com/dobyte/due/log" + "net" + "sync" + "sync/atomic" + + "github.com/gorilla/websocket" + + "github.com/dobyte/due/network" +) + +type serverConn struct { + rw sync.RWMutex // 锁 + id int64 // 连接ID + uid int64 // 用户ID + state int32 // 连接状态 + conn *websocket.Conn // WS源连接 + connMgr *connMgr // 连接管理 + chWrite chan chWrite // 写入队列 + done chan struct{} // 写入完成信号 +} + +var _ network.Conn = &serverConn{} + +// ID 获取连接ID +func (c *serverConn) ID() int64 { + return c.id +} + +// UID 获取用户ID +func (c *serverConn) UID() int64 { + c.rw.RLock() + defer c.rw.RUnlock() + + return c.uid +} + +// Bind 绑定用户ID +func (c *serverConn) Bind(uid int64) { + c.rw.Lock() + defer c.rw.Unlock() + + c.uid = uid +} + +// Send 发送消息(同步) +func (c *serverConn) Send(msg []byte, msgType ...int) error { + c.rw.RLock() + defer c.rw.RUnlock() + + if err := c.checkState(); err != nil { + return err + } + + if len(msgType) == 0 { + msgType = append(msgType, TextMessage) + } + + switch msgType[0] { + case TextMessage, BinaryMessage: + return c.conn.WriteMessage(msgType[0], msg) + default: + return network.ErrIllegalMsgType + } +} + +// Push 发送消息(异步) +func (c *serverConn) Push(msg []byte, msgType ...int) error { + c.rw.RLock() + defer c.rw.RUnlock() + + if err := c.checkState(); err != nil { + return err + } + + if len(msgType) == 0 { + msgType = append(msgType, TextMessage) + } + + switch msgType[0] { + case TextMessage, BinaryMessage: + c.chWrite <- chWrite{typ: dataPacket, msg: msg, msgType: msgType[0]} + default: + return network.ErrIllegalMsgType + } + + return nil +} + +// State 获取连接状态 +func (c *serverConn) State() network.ConnState { + return network.ConnState(atomic.LoadInt32(&c.state)) +} + +// Close 关闭连接 +func (c *serverConn) Close(isForce ...bool) error { + c.rw.Lock() + defer c.rw.Unlock() + + if err := c.checkState(); err != nil { + return err + } + + if len(isForce) > 0 && isForce[0] { + atomic.StoreInt32(&c.state, int32(network.ConnClosed)) + } else { + atomic.StoreInt32(&c.state, int32(network.ConnHanged)) + c.chWrite <- chWrite{typ: closeSig} + <-c.done + } + + close(c.chWrite) + + err := c.conn.Close() + c.conn = nil + c.connMgr.recycle(c) + + return err +} + +// 关闭连接 +func (c *serverConn) close() { + atomic.StoreInt32(&c.state, int32(network.ConnClosed)) + + if c.connMgr.server.disconnectHandler != nil { + c.connMgr.server.disconnectHandler(c) + } +} + +// LocalIP 获取本地IP +func (c *serverConn) LocalIP() (string, error) { + addr, err := c.LocalAddr() + if err != nil { + return "", err + } + + return xnet.ExtractIP(addr) +} + +// LocalAddr 获取本地地址 +func (c *serverConn) LocalAddr() (net.Addr, error) { + c.rw.RLock() + defer c.rw.RUnlock() + + if err := c.checkState(); err != nil { + return nil, err + } + + return c.conn.LocalAddr(), nil +} + +// RemoteIP 获取远端IP +func (c *serverConn) RemoteIP() (string, error) { + addr, err := c.RemoteAddr() + if err != nil { + return "", err + } + + return xnet.ExtractIP(addr) +} + +// RemoteAddr 获取远端地址 +func (c *serverConn) RemoteAddr() (net.Addr, error) { + c.rw.RLock() + defer c.rw.RUnlock() + + if err := c.checkState(); err != nil { + return nil, err + } + + return c.conn.RemoteAddr(), nil +} + +// 初始化连接 +func (c *serverConn) init(conn *websocket.Conn, cm *connMgr) { + c.id = cm.id + c.conn = conn + c.connMgr = cm + c.chWrite = make(chan chWrite, 256) + c.done = make(chan struct{}) + atomic.StoreInt32(&c.state, int32(network.ConnOpened)) + + if c.connMgr.server.connectHandler != nil { + c.connMgr.server.connectHandler(c) + } + + go c.read() + + go c.write() +} + +// 检测连接状态 +func (c *serverConn) checkState() error { + switch network.ConnState(atomic.LoadInt32(&c.state)) { + case network.ConnHanged: + return network.ErrConnectionHanged + case network.ConnClosed: + return network.ErrConnectionClosed + } + + return nil +} + +// 读取消息 +func (c *serverConn) read() { + defer c.close() + + for { + msgType, buf, err := c.conn.ReadMessage() + if err != nil { + break + } + + switch c.State() { + case network.ConnHanged: + continue + case network.ConnClosed: + return + } + + if c.connMgr.server.receiveHandler != nil { + c.connMgr.server.receiveHandler(c, buf, msgType) + } + } +} + +// 写入消息 +func (c *serverConn) write() { + for { + select { + case write, ok := <-c.chWrite: + if !ok { + return + } + + if write.typ == closeSig { + c.done <- struct{}{} + return + } + + if err := c.conn.WriteMessage(write.msgType, write.msg); err != nil { + log.Errorf("write message error: %v", err) + } + } + } +} diff --git a/network/ws/server_conn_mgr.go b/network/ws/server_conn_mgr.go new file mode 100644 index 00000000..dbd17d47 --- /dev/null +++ b/network/ws/server_conn_mgr.go @@ -0,0 +1,67 @@ +/** + * @Author: fuxiao + * @Email: 576101059@qq.com + * @Date: 2022/5/28 3:48 下午 + * @Desc: 连接管理器 + */ + +package ws + +import ( + "github.com/dobyte/due/network" + "sync" + + "github.com/gorilla/websocket" +) + +type connMgr struct { + rw sync.RWMutex // 连接读写锁 + id int64 // 连接ID + pool sync.Pool // 连接池 + conns map[*websocket.Conn]*serverConn // 连接集合 + server *server // 服务器 +} + +func newConnMgr(server *server) *connMgr { + return &connMgr{ + server: server, + conns: make(map[*websocket.Conn]*serverConn), + pool: sync.Pool{New: func() interface{} { return &serverConn{} }}, + } +} + +// 关闭连接 +func (cm *connMgr) close() { + cm.rw.Lock() + defer cm.rw.RUnlock() + + for _, conn := range cm.conns { + _ = conn.Close(false) + } +} + +// 分配连接 +func (cm *connMgr) allocate(c *websocket.Conn) error { + cm.rw.Lock() + defer cm.rw.Unlock() + + if len(cm.conns) >= cm.server.opts.maxConnNum { + return network.ErrTooManyConnection + } + + cm.id++ + conn := cm.pool.Get().(*serverConn) + conn.init(c, cm) + cm.conns[c] = conn + + return nil +} + +// 回收连接 +func (cm *connMgr) recycle(conn *serverConn) { + cm.rw.Lock() + defer cm.rw.Unlock() + + delete(cm.conns, conn.conn) + cm.pool.Put(conn) +} diff --git a/network/ws/server_options.go b/network/ws/server_options.go new file mode 100644 index 00000000..d4f1779b --- /dev/null +++ b/network/ws/server_options.go @@ -0,0 +1,43 @@ +package ws + +import ( + "net/http" +) + +type ServerOption func(o *serverOptions) + +type CheckOriginFunc func(r *http.Request) bool + +type serverOptions struct { + addr string // 监听地址 + maxConnNum int // 最大连接数 + certFile string // 证书文件 + keyFile string // 秘钥文件 + path string // 路径,默认为"/" + checkOrigin CheckOriginFunc // 跨域检测 +} + +// WithServerListenAddr 设置监听地址 +func WithServerListenAddr(addr string) ServerOption { + return func(o *serverOptions) { o.addr = addr } +} + +// WithServerMaxConnNum 设置连接的最大连接数 +func WithServerMaxConnNum(maxConnNum int) ServerOption { + return func(o *serverOptions) { o.maxConnNum = maxConnNum } +} + +// WithServerPath 设置Websocket的连接路径 +func WithServerPath(path string) ServerOption { + return func(o *serverOptions) { o.path = path } +} + +// WithServerCredentials 设置证书和秘钥 +func WithServerCredentials(certFile, keyFile string) ServerOption { + return func(o *serverOptions) { o.keyFile, o.certFile = keyFile, certFile } +} + +// WithServerCheckOrigin 设置Websocket跨域检测函数 +func WithServerCheckOrigin(checkOrigin CheckOriginFunc) ServerOption { + return func(o *serverOptions) { o.checkOrigin = checkOrigin } +} diff --git a/network/ws/server_test.go b/network/ws/server_test.go new file mode 100644 index 00000000..3552a7c3 --- /dev/null +++ b/network/ws/server_test.go @@ -0,0 +1,46 @@ +/** + * @Author: fuxiao + * @Email: 576101059@qq.com + * @Date: 2022/5/29 10:59 上午 + * @Desc: TODO + */ + +package ws_test + +import ( + "fmt" + "github.com/dobyte/due/network/ws" + "net/http" + "testing" + + "github.com/dobyte/due/network" +) + +func TestServer(t *testing.T) { + server := ws.NewServer( + ws.WithAddr(":8088"), + ws.WithMaxConnNum(5), + ws.WithCheckOrigin(func(r *http.Request) bool { return true }), + ) + + server.OnStart(func(s network.Server) { + fmt.Println("server is started") + }) + server.OnConnect(func(s network.Server, conn network.Conn) { + fmt.Printf("conn is opened, conn id: %d\n", conn.ID()) + }) + server.OnDisconnect(func(s network.Server, conn network.Conn) { + fmt.Printf("conn is closed, conn id: %d\n", conn.ID()) + }) + server.OnReceive(func(s network.Server, conn network.Conn, msg []byte, msgType int) { + fmt.Printf("receive msg from conn, conn id: %d, msg: %s\n", conn.ID(), string(msg)) + _ = conn.Close() + if err := conn.Send([]byte("hello world")); err != nil { + fmt.Println(err) + } + }) + + _ = server.Start() + + select {} +}