diff --git a/pkg/app/server/option.go b/pkg/app/server/option.go index fcf380485..0a03792f8 100644 --- a/pkg/app/server/option.go +++ b/pkg/app/server/option.go @@ -332,6 +332,24 @@ func WithDisablePrintRoute(b bool) config.Option { }} } +// WithSenseClientDisconnection sets the ability to sense client disconnections. +// If we don't set it, it will default to false. +// There are three issues to note when using this option: +// 1. It only applies to netpoll. +// 2. It needs to be used in conjunction with WithOnAccept. +// Examples: +// server.Default( +// server.WithSenseClientDisconnection(true), +// server.WithOnConnect(func(ctx context.Context, conn network.Conn) context.Context { +// return ctx +// })) +// 3. The cost is high after opening, please choose carefully. +func WithSenseClientDisconnection(b bool) config.Option { + return config.Option{F: func(o *config.Options) { + o.SenseClientDisconnection = b + }} +} + // WithOnAccept sets the callback function when a new connection is accepted but cannot // receive data in netpoll. In go net, it will be called before converting tls connection func WithOnAccept(fn func(conn net.Conn) context.Context) config.Option { diff --git a/pkg/app/server/option_test.go b/pkg/app/server/option_test.go index f5d7f7b32..bc43e7cd2 100644 --- a/pkg/app/server/option_test.go +++ b/pkg/app/server/option_test.go @@ -61,6 +61,7 @@ func TestOptions(t *testing.T) { WithBasePath("/"), WithMaxRequestBodySize(2), WithDisablePrintRoute(true), + WithSenseClientDisconnection(true), WithNetwork("unix"), WithExitWaitTime(time.Second), WithMaxKeepBodySize(500), @@ -93,6 +94,7 @@ func TestOptions(t *testing.T) { assert.DeepEqual(t, opt.BasePath, "/") assert.DeepEqual(t, opt.MaxRequestBodySize, 2) assert.DeepEqual(t, opt.DisablePrintRoute, true) + assert.DeepEqual(t, opt.SenseClientDisconnection, true) assert.DeepEqual(t, opt.Network, "unix") assert.DeepEqual(t, opt.ExitWaitTimeout, time.Second) assert.DeepEqual(t, opt.MaxKeepBodySize, 500) @@ -130,6 +132,7 @@ func TestDefaultOptions(t *testing.T) { assert.DeepEqual(t, opt.GetOnly, false) assert.DeepEqual(t, opt.DisableKeepalive, false) assert.DeepEqual(t, opt.DisablePrintRoute, false) + assert.DeepEqual(t, opt.SenseClientDisconnection, false) assert.DeepEqual(t, opt.Network, "tcp") assert.DeepEqual(t, opt.ExitWaitTimeout, time.Second*5) assert.DeepEqual(t, opt.MaxKeepBodySize, 4*1024*1024) diff --git a/pkg/common/config/option.go b/pkg/common/config/option.go index 048fb366f..bc6c24c51 100644 --- a/pkg/common/config/option.go +++ b/pkg/common/config/option.go @@ -61,6 +61,7 @@ type Options struct { StreamRequestBody bool NoDefaultServerHeader bool DisablePrintRoute bool + SenseClientDisconnection bool Network string Addr string BasePath string @@ -195,6 +196,9 @@ func NewOptions(opts []Option) *Options { // Disabled when set to True DisablePrintRoute: false, + // The ability to sense client disconnection is disabled by default + SenseClientDisconnection: false, + // "tcp", "udp", "unix"(unix domain socket) Network: defaultNetwork, diff --git a/pkg/common/config/option_test.go b/pkg/common/config/option_test.go index 67fcab796..315ec9c97 100644 --- a/pkg/common/config/option_test.go +++ b/pkg/common/config/option_test.go @@ -37,6 +37,7 @@ func TestDefaultOptions(t *testing.T) { assert.False(t, options.HandleMethodNotAllowed) assert.False(t, options.UseRawPath) assert.False(t, options.RemoveExtraSlash) + assert.False(t, options.SenseClientDisconnection) assert.True(t, options.UnescapePathValues) assert.False(t, options.DisablePreParseMultipartForm) assert.DeepEqual(t, defaultNetwork, options.Network) diff --git a/pkg/network/netpoll/transport.go b/pkg/network/netpoll/transport.go index 17829cb83..72e69ca7f 100644 --- a/pkg/network/netpoll/transport.go +++ b/pkg/network/netpoll/transport.go @@ -38,31 +38,33 @@ func init() { type transporter struct { sync.RWMutex - network string - addr string - keepAliveTimeout time.Duration - readTimeout time.Duration - writeTimeout time.Duration - listener net.Listener - eventLoop netpoll.EventLoop - listenConfig *net.ListenConfig - OnAccept func(conn net.Conn) context.Context - OnConnect func(ctx context.Context, conn network.Conn) context.Context + network string + addr string + senseClientDisconnection bool + keepAliveTimeout time.Duration + readTimeout time.Duration + writeTimeout time.Duration + listener net.Listener + eventLoop netpoll.EventLoop + listenConfig *net.ListenConfig + OnAccept func(conn net.Conn) context.Context + OnConnect func(ctx context.Context, conn network.Conn) context.Context } // For transporter switch func NewTransporter(options *config.Options) network.Transporter { return &transporter{ - network: options.Network, - addr: options.Addr, - keepAliveTimeout: options.KeepAliveTimeout, - readTimeout: options.ReadTimeout, - writeTimeout: options.WriteTimeout, - listener: nil, - eventLoop: nil, - listenConfig: options.ListenConfig, - OnAccept: options.OnAccept, - OnConnect: options.OnConnect, + network: options.Network, + addr: options.Addr, + senseClientDisconnection: options.SenseClientDisconnection, + keepAliveTimeout: options.KeepAliveTimeout, + readTimeout: options.ReadTimeout, + writeTimeout: options.WriteTimeout, + listener: nil, + eventLoop: nil, + listenConfig: options.ListenConfig, + OnAccept: options.OnAccept, + OnConnect: options.OnConnect, } } @@ -97,6 +99,14 @@ func (t *transporter) ListenAndServe(onReq network.OnData) (err error) { if t.OnConnect != nil { opts = append(opts, netpoll.WithOnConnect(func(ctx context.Context, conn netpoll.Connection) context.Context { + if t.senseClientDisconnection { + ctx, cancel := context.WithCancel(ctx) + conn.AddCloseCallback(func(connection netpoll.Connection) error { + cancel() + return nil + }) + return t.OnConnect(ctx, newConn(conn)) + } return t.OnConnect(ctx, newConn(conn)) })) }