diff --git a/go.mod b/go.mod index 3619b833d..36f7e5132 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/bytedance/gopkg v0.0.0-20220413063733-65bf48ffb3a7 github.com/bytedance/mockey v1.2.1 github.com/bytedance/sonic v1.8.1 - github.com/cloudwego/netpoll v0.5.0 + github.com/cloudwego/netpoll v0.6.0 github.com/fsnotify/fsnotify v1.5.4 github.com/tidwall/gjson v1.14.4 golang.org/x/sync v0.0.0-20210220032951-036812b2e83c diff --git a/go.sum b/go.sum index 3cf3c2f96..4b60e994d 100644 --- a/go.sum +++ b/go.sum @@ -10,8 +10,8 @@ github.com/bytedance/sonic v1.8.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZX github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= -github.com/cloudwego/netpoll v0.5.0 h1:oRrOp58cPCvK2QbMozZNDESvrxQaEHW2dCimmwH1lcU= -github.com/cloudwego/netpoll v0.5.0/go.mod h1:xVefXptcyheopwNDZjDPcfU6kIjZXZ4nY550k1yH9eQ= +github.com/cloudwego/netpoll v0.6.0 h1:JRMkrA1o8k/4quxzg6Q1XM+zIhwZsyoWlq6ef+ht31U= +github.com/cloudwego/netpoll v0.6.0/go.mod h1:xVefXptcyheopwNDZjDPcfU6kIjZXZ4nY550k1yH9eQ= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/pkg/app/server/hertz_unix_test.go b/pkg/app/server/hertz_unix_test.go index b37ddfbdf..b1f7d700c 100644 --- a/pkg/app/server/hertz_unix_test.go +++ b/pkg/app/server/hertz_unix_test.go @@ -34,6 +34,7 @@ import ( c "github.com/cloudwego/hertz/pkg/app/client" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/common/utils" + "github.com/cloudwego/hertz/pkg/network" "github.com/cloudwego/hertz/pkg/network/standard" "github.com/cloudwego/hertz/pkg/protocol/consts" "golang.org/x/sys/unix" @@ -134,3 +135,57 @@ func TestHertz_Spin(t *testing.T) { <-ch2 } + +func TestWithSenseClientDisconnection(t *testing.T) { + var closeFlag int32 + h := New(WithHostPorts("127.0.0.1:6631"), WithSenseClientDisconnection(true)) + h.GET("/ping", func(c context.Context, ctx *app.RequestContext) { + assert.DeepEqual(t, "aa", string(ctx.Host())) + ch := make(chan struct{}) + select { + case <-c.Done(): + atomic.StoreInt32(&closeFlag, 1) + assert.DeepEqual(t, context.Canceled, c.Err()) + case <-ch: + } + }) + go h.Spin() + time.Sleep(time.Second) + con, err := net.Dial("tcp", "127.0.0.1:6631") + assert.Nil(t, err) + _, err = con.Write([]byte("GET /ping HTTP/1.1\r\nHost: aa\r\n\r\n")) + assert.Nil(t, err) + time.Sleep(time.Second) + assert.DeepEqual(t, atomic.LoadInt32(&closeFlag), int32(0)) + assert.Nil(t, con.Close()) + time.Sleep(time.Second) + assert.DeepEqual(t, atomic.LoadInt32(&closeFlag), int32(1)) +} + +func TestWithSenseClientDisconnectionAndWithOnConnect(t *testing.T) { + var closeFlag int32 + h := New(WithHostPorts("127.0.0.1:6632"), WithSenseClientDisconnection(true), WithOnConnect(func(ctx context.Context, conn network.Conn) context.Context { + return ctx + })) + h.GET("/ping", func(c context.Context, ctx *app.RequestContext) { + assert.DeepEqual(t, "aa", string(ctx.Host())) + ch := make(chan struct{}) + select { + case <-c.Done(): + atomic.StoreInt32(&closeFlag, 1) + assert.DeepEqual(t, context.Canceled, c.Err()) + case <-ch: + } + }) + go h.Spin() + time.Sleep(time.Second) + con, err := net.Dial("tcp", "127.0.0.1:6632") + assert.Nil(t, err) + _, err = con.Write([]byte("GET /ping HTTP/1.1\r\nHost: aa\r\n\r\n")) + assert.Nil(t, err) + time.Sleep(time.Second) + assert.DeepEqual(t, atomic.LoadInt32(&closeFlag), int32(0)) + assert.Nil(t, con.Close()) + time.Sleep(time.Second) + assert.DeepEqual(t, atomic.LoadInt32(&closeFlag), int32(1)) +} diff --git a/pkg/app/server/option.go b/pkg/app/server/option.go index e7970348c..18f184379 100644 --- a/pkg/app/server/option.go +++ b/pkg/app/server/option.go @@ -394,3 +394,19 @@ func WithDisableDefaultContentType(disable bool) config.Option { o.NoDefaultContentType = disable }} } + +// WithSenseClientDisconnection sets the ability to sense client disconnections. +// If we don't set it, it will default to false. +// There are two issues to note when using this option: +// 1. Warning: It only applies to netpoll. +// 2. After opening, the context.Context in the request will be cancelled. +// +// Example: +// server.Default( +// server.WithSenseClientDisconnection(true), +// ) +func WithSenseClientDisconnection(b bool) config.Option { + return config.Option{F: func(o *config.Options) { + o.SenseClientDisconnection = b + }} +} diff --git a/pkg/app/server/option_test.go b/pkg/app/server/option_test.go index f5d7f7b32..bc43e7cd2 100644 --- a/pkg/app/server/option_test.go +++ b/pkg/app/server/option_test.go @@ -61,6 +61,7 @@ func TestOptions(t *testing.T) { WithBasePath("/"), WithMaxRequestBodySize(2), WithDisablePrintRoute(true), + WithSenseClientDisconnection(true), WithNetwork("unix"), WithExitWaitTime(time.Second), WithMaxKeepBodySize(500), @@ -93,6 +94,7 @@ func TestOptions(t *testing.T) { assert.DeepEqual(t, opt.BasePath, "/") assert.DeepEqual(t, opt.MaxRequestBodySize, 2) assert.DeepEqual(t, opt.DisablePrintRoute, true) + assert.DeepEqual(t, opt.SenseClientDisconnection, true) assert.DeepEqual(t, opt.Network, "unix") assert.DeepEqual(t, opt.ExitWaitTimeout, time.Second) assert.DeepEqual(t, opt.MaxKeepBodySize, 500) @@ -130,6 +132,7 @@ func TestDefaultOptions(t *testing.T) { assert.DeepEqual(t, opt.GetOnly, false) assert.DeepEqual(t, opt.DisableKeepalive, false) assert.DeepEqual(t, opt.DisablePrintRoute, false) + assert.DeepEqual(t, opt.SenseClientDisconnection, false) assert.DeepEqual(t, opt.Network, "tcp") assert.DeepEqual(t, opt.ExitWaitTimeout, time.Second*5) assert.DeepEqual(t, opt.MaxKeepBodySize, 4*1024*1024) diff --git a/pkg/common/config/option.go b/pkg/common/config/option.go index 417955fc9..958d9b3a3 100644 --- a/pkg/common/config/option.go +++ b/pkg/common/config/option.go @@ -63,6 +63,7 @@ type Options struct { StreamRequestBody bool NoDefaultServerHeader bool DisablePrintRoute bool + SenseClientDisconnection bool Network string Addr string BasePath string @@ -203,6 +204,9 @@ func NewOptions(opts []Option) *Options { // Disabled when set to True DisablePrintRoute: false, + // The ability to sense client disconnection is disabled by default + SenseClientDisconnection: false, + // "tcp", "udp", "unix"(unix domain socket) Network: defaultNetwork, diff --git a/pkg/common/config/option_test.go b/pkg/common/config/option_test.go index 67fcab796..b836a3a1a 100644 --- a/pkg/common/config/option_test.go +++ b/pkg/common/config/option_test.go @@ -39,6 +39,7 @@ func TestDefaultOptions(t *testing.T) { assert.False(t, options.RemoveExtraSlash) assert.True(t, options.UnescapePathValues) assert.False(t, options.DisablePreParseMultipartForm) + assert.False(t, options.SenseClientDisconnection) assert.DeepEqual(t, defaultNetwork, options.Network) assert.DeepEqual(t, defaultAddr, options.Addr) assert.DeepEqual(t, defaultMaxRequestBodySize, options.MaxRequestBodySize) diff --git a/pkg/network/netpoll/transport.go b/pkg/network/netpoll/transport.go index 17829cb83..7450292ce 100644 --- a/pkg/network/netpoll/transport.go +++ b/pkg/network/netpoll/transport.go @@ -36,33 +36,45 @@ func init() { netpoll.SetLoggerOutput(io.Discard) } +type ctxCancelKeyStruct struct{} + +var ctxCancelKey = ctxCancelKeyStruct{} + +func cancelContext(ctx context.Context) context.Context { + ctx, cancel := context.WithCancel(ctx) + ctx = context.WithValue(ctx, ctxCancelKey, cancel) + return ctx +} + type transporter struct { sync.RWMutex - network string - addr string - keepAliveTimeout time.Duration - readTimeout time.Duration - writeTimeout time.Duration - listener net.Listener - eventLoop netpoll.EventLoop - listenConfig *net.ListenConfig - OnAccept func(conn net.Conn) context.Context - OnConnect func(ctx context.Context, conn network.Conn) context.Context + senseClientDisconnection bool + network string + addr string + keepAliveTimeout time.Duration + readTimeout time.Duration + writeTimeout time.Duration + listener net.Listener + eventLoop netpoll.EventLoop + listenConfig *net.ListenConfig + OnAccept func(conn net.Conn) context.Context + OnConnect func(ctx context.Context, conn network.Conn) context.Context } // For transporter switch func NewTransporter(options *config.Options) network.Transporter { return &transporter{ - network: options.Network, - addr: options.Addr, - keepAliveTimeout: options.KeepAliveTimeout, - readTimeout: options.ReadTimeout, - writeTimeout: options.WriteTimeout, - listener: nil, - eventLoop: nil, - listenConfig: options.ListenConfig, - OnAccept: options.OnAccept, - OnConnect: options.OnConnect, + senseClientDisconnection: options.SenseClientDisconnection, + network: options.Network, + addr: options.Addr, + keepAliveTimeout: options.KeepAliveTimeout, + readTimeout: options.ReadTimeout, + writeTimeout: options.WriteTimeout, + listener: nil, + eventLoop: nil, + listenConfig: options.ListenConfig, + OnAccept: options.OnAccept, + OnConnect: options.OnConnect, } } @@ -88,10 +100,14 @@ func (t *transporter) ListenAndServe(onReq network.OnData) (err error) { if t.writeTimeout > 0 { conn.SetWriteTimeout(t.writeTimeout) } + ctx := context.Background() if t.OnAccept != nil { - return t.OnAccept(newConn(conn)) + ctx = t.OnAccept(newConn(conn)) + } + if t.senseClientDisconnection { + ctx = cancelContext(ctx) } - return context.Background() + return ctx }), } @@ -101,6 +117,15 @@ func (t *transporter) ListenAndServe(onReq network.OnData) (err error) { })) } + if t.senseClientDisconnection { + opts = append(opts, netpoll.WithOnDisconnect(func(ctx context.Context, connection netpoll.Connection) { + cancelFunc, ok := ctx.Value(ctxCancelKey).(context.CancelFunc) + if cancelFunc != nil && ok { + cancelFunc() + } + })) + } + // Create EventLoop t.Lock() t.eventLoop, err = netpoll.NewEventLoop(func(ctx context.Context, connection netpoll.Connection) error { diff --git a/pkg/network/netpoll/transport_test.go b/pkg/network/netpoll/transport_test.go index d8a06090c..53239209e 100644 --- a/pkg/network/netpoll/transport_test.go +++ b/pkg/network/netpoll/transport_test.go @@ -69,6 +69,35 @@ func TestTransport(t *testing.T) { assert.Assert(t, atomic.LoadInt32(&onDataFlag) == 1) }) + t.Run("TestSenseClientDisconnection", func(t *testing.T) { + var onReqFlag int32 + transporter := NewTransporter(&config.Options{ + Addr: addr, + Network: nw, + SenseClientDisconnection: true, + }) + + go transporter.ListenAndServe(func(ctx context.Context, conn interface{}) error { + atomic.StoreInt32(&onReqFlag, 1) + time.Sleep(100 * time.Millisecond) + assert.DeepEqual(t, context.Canceled, ctx.Err()) + return nil + }) + defer transporter.Close() + time.Sleep(100 * time.Millisecond) + + dial := NewDialer() + conn, err := dial.DialConnection(nw, addr, time.Second, nil) + assert.Nil(t, err) + _, err = conn.Write([]byte("123")) + assert.Nil(t, err) + err = conn.Close() + assert.Nil(t, err) + time.Sleep(100 * time.Millisecond) + + assert.Assert(t, atomic.LoadInt32(&onReqFlag) == 1) + }) + t.Run("TestListenConfig", func(t *testing.T) { listenCfg := &net.ListenConfig{Control: func(network, address string, c syscall.RawConn) error { return c.Control(func(fd uintptr) {