Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support feeling client disconnetion #1054

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ require (
github.com/bytedance/gopkg v0.0.0-20220413063733-65bf48ffb3a7
github.com/bytedance/mockey v1.2.1
github.com/bytedance/sonic v1.8.1
github.com/cloudwego/netpoll v0.5.0
github.com/cloudwego/netpoll v0.6.0
github.com/fsnotify/fsnotify v1.5.4
github.com/tidwall/gjson v1.14.4
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ github.com/bytedance/sonic v1.8.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZX
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
github.com/cloudwego/netpoll v0.5.0 h1:oRrOp58cPCvK2QbMozZNDESvrxQaEHW2dCimmwH1lcU=
github.com/cloudwego/netpoll v0.5.0/go.mod h1:xVefXptcyheopwNDZjDPcfU6kIjZXZ4nY550k1yH9eQ=
github.com/cloudwego/netpoll v0.6.0 h1:JRMkrA1o8k/4quxzg6Q1XM+zIhwZsyoWlq6ef+ht31U=
github.com/cloudwego/netpoll v0.6.0/go.mod h1:xVefXptcyheopwNDZjDPcfU6kIjZXZ4nY550k1yH9eQ=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
Expand Down
55 changes: 55 additions & 0 deletions pkg/app/server/hertz_unix_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
c "github.com/cloudwego/hertz/pkg/app/client"
"github.com/cloudwego/hertz/pkg/common/test/assert"
"github.com/cloudwego/hertz/pkg/common/utils"
"github.com/cloudwego/hertz/pkg/network"
"github.com/cloudwego/hertz/pkg/network/standard"
"github.com/cloudwego/hertz/pkg/protocol/consts"
"golang.org/x/sys/unix"
Expand Down Expand Up @@ -134,3 +135,57 @@ func TestHertz_Spin(t *testing.T) {

<-ch2
}

func TestWithSenseClientDisconnection(t *testing.T) {
var closeFlag int32
h := New(WithHostPorts("127.0.0.1:6631"), WithSenseClientDisconnection(true))
h.GET("/ping", func(c context.Context, ctx *app.RequestContext) {
assert.DeepEqual(t, "aa", string(ctx.Host()))
ch := make(chan struct{})
select {
case <-c.Done():
atomic.StoreInt32(&closeFlag, 1)
assert.DeepEqual(t, context.Canceled, c.Err())
case <-ch:
}
})
go h.Spin()
time.Sleep(time.Second)
con, err := net.Dial("tcp", "127.0.0.1:6631")
assert.Nil(t, err)
_, err = con.Write([]byte("GET /ping HTTP/1.1\r\nHost: aa\r\n\r\n"))
assert.Nil(t, err)
time.Sleep(time.Second)
assert.DeepEqual(t, atomic.LoadInt32(&closeFlag), int32(0))
assert.Nil(t, con.Close())
time.Sleep(time.Second)
assert.DeepEqual(t, atomic.LoadInt32(&closeFlag), int32(1))
}

func TestWithSenseClientDisconnectionAndWithOnConnect(t *testing.T) {
var closeFlag int32
h := New(WithHostPorts("127.0.0.1:6632"), WithSenseClientDisconnection(true), WithOnConnect(func(ctx context.Context, conn network.Conn) context.Context {
return ctx
}))
h.GET("/ping", func(c context.Context, ctx *app.RequestContext) {
assert.DeepEqual(t, "aa", string(ctx.Host()))
ch := make(chan struct{})
select {
case <-c.Done():
atomic.StoreInt32(&closeFlag, 1)
assert.DeepEqual(t, context.Canceled, c.Err())
case <-ch:
}
})
go h.Spin()
time.Sleep(time.Second)
con, err := net.Dial("tcp", "127.0.0.1:6632")
assert.Nil(t, err)
_, err = con.Write([]byte("GET /ping HTTP/1.1\r\nHost: aa\r\n\r\n"))
assert.Nil(t, err)
time.Sleep(time.Second)
assert.DeepEqual(t, atomic.LoadInt32(&closeFlag), int32(0))
assert.Nil(t, con.Close())
time.Sleep(time.Second)
assert.DeepEqual(t, atomic.LoadInt32(&closeFlag), int32(1))
}
16 changes: 16 additions & 0 deletions pkg/app/server/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -394,3 +394,19 @@ func WithDisableDefaultContentType(disable bool) config.Option {
o.NoDefaultContentType = disable
}}
}

// WithSenseClientDisconnection sets the ability to sense client disconnections.
// If we don't set it, it will default to false.
// There are two issues to note when using this option:
// 1. Warning: It only applies to netpoll.
// 2. After opening, the context.Context in the request will be cancelled.
//
// Example:
// server.Default(
// server.WithSenseClientDisconnection(true),
// )
func WithSenseClientDisconnection(b bool) config.Option {
return config.Option{F: func(o *config.Options) {
o.SenseClientDisconnection = b
}}
}
3 changes: 3 additions & 0 deletions pkg/app/server/option_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ func TestOptions(t *testing.T) {
WithBasePath("/"),
WithMaxRequestBodySize(2),
WithDisablePrintRoute(true),
WithSenseClientDisconnection(true),
WithNetwork("unix"),
WithExitWaitTime(time.Second),
WithMaxKeepBodySize(500),
Expand Down Expand Up @@ -93,6 +94,7 @@ func TestOptions(t *testing.T) {
assert.DeepEqual(t, opt.BasePath, "/")
assert.DeepEqual(t, opt.MaxRequestBodySize, 2)
assert.DeepEqual(t, opt.DisablePrintRoute, true)
assert.DeepEqual(t, opt.SenseClientDisconnection, true)
assert.DeepEqual(t, opt.Network, "unix")
assert.DeepEqual(t, opt.ExitWaitTimeout, time.Second)
assert.DeepEqual(t, opt.MaxKeepBodySize, 500)
Expand Down Expand Up @@ -130,6 +132,7 @@ func TestDefaultOptions(t *testing.T) {
assert.DeepEqual(t, opt.GetOnly, false)
assert.DeepEqual(t, opt.DisableKeepalive, false)
assert.DeepEqual(t, opt.DisablePrintRoute, false)
assert.DeepEqual(t, opt.SenseClientDisconnection, false)
assert.DeepEqual(t, opt.Network, "tcp")
assert.DeepEqual(t, opt.ExitWaitTimeout, time.Second*5)
assert.DeepEqual(t, opt.MaxKeepBodySize, 4*1024*1024)
Expand Down
4 changes: 4 additions & 0 deletions pkg/common/config/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ type Options struct {
StreamRequestBody bool
NoDefaultServerHeader bool
DisablePrintRoute bool
SenseClientDisconnection bool
Network string
Addr string
BasePath string
Expand Down Expand Up @@ -203,6 +204,9 @@ func NewOptions(opts []Option) *Options {
// Disabled when set to True
DisablePrintRoute: false,

// The ability to sense client disconnection is disabled by default
SenseClientDisconnection: false,

// "tcp", "udp", "unix"(unix domain socket)
Network: defaultNetwork,

Expand Down
1 change: 1 addition & 0 deletions pkg/common/config/option_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ func TestDefaultOptions(t *testing.T) {
assert.False(t, options.RemoveExtraSlash)
assert.True(t, options.UnescapePathValues)
assert.False(t, options.DisablePreParseMultipartForm)
assert.False(t, options.SenseClientDisconnection)
assert.DeepEqual(t, defaultNetwork, options.Network)
assert.DeepEqual(t, defaultAddr, options.Addr)
assert.DeepEqual(t, defaultMaxRequestBodySize, options.MaxRequestBodySize)
Expand Down
69 changes: 47 additions & 22 deletions pkg/network/netpoll/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,33 +36,45 @@ func init() {
netpoll.SetLoggerOutput(io.Discard)
}

type ctxCancelKeyStruct struct{}

var ctxCancelKey = ctxCancelKeyStruct{}

func cancelContext(ctx context.Context) context.Context {
ctx, cancel := context.WithCancel(ctx)
ctx = context.WithValue(ctx, ctxCancelKey, cancel)
return ctx
}

type transporter struct {
sync.RWMutex
network string
addr string
keepAliveTimeout time.Duration
readTimeout time.Duration
writeTimeout time.Duration
listener net.Listener
eventLoop netpoll.EventLoop
listenConfig *net.ListenConfig
OnAccept func(conn net.Conn) context.Context
OnConnect func(ctx context.Context, conn network.Conn) context.Context
senseClientDisconnection bool
network string
addr string
keepAliveTimeout time.Duration
readTimeout time.Duration
writeTimeout time.Duration
listener net.Listener
eventLoop netpoll.EventLoop
listenConfig *net.ListenConfig
OnAccept func(conn net.Conn) context.Context
OnConnect func(ctx context.Context, conn network.Conn) context.Context
}

// For transporter switch
func NewTransporter(options *config.Options) network.Transporter {
return &transporter{
network: options.Network,
addr: options.Addr,
keepAliveTimeout: options.KeepAliveTimeout,
readTimeout: options.ReadTimeout,
writeTimeout: options.WriteTimeout,
listener: nil,
eventLoop: nil,
listenConfig: options.ListenConfig,
OnAccept: options.OnAccept,
OnConnect: options.OnConnect,
senseClientDisconnection: options.SenseClientDisconnection,
network: options.Network,
addr: options.Addr,
keepAliveTimeout: options.KeepAliveTimeout,
readTimeout: options.ReadTimeout,
writeTimeout: options.WriteTimeout,
listener: nil,
eventLoop: nil,
listenConfig: options.ListenConfig,
OnAccept: options.OnAccept,
OnConnect: options.OnConnect,
}
}

Expand All @@ -88,10 +100,14 @@ func (t *transporter) ListenAndServe(onReq network.OnData) (err error) {
if t.writeTimeout > 0 {
conn.SetWriteTimeout(t.writeTimeout)
}
ctx := context.Background()
if t.OnAccept != nil {
return t.OnAccept(newConn(conn))
ctx = t.OnAccept(newConn(conn))
}
if t.senseClientDisconnection {
ctx = cancelContext(ctx)
}
return context.Background()
return ctx
}),
}

Expand All @@ -101,6 +117,15 @@ func (t *transporter) ListenAndServe(onReq network.OnData) (err error) {
}))
}

if t.senseClientDisconnection {
opts = append(opts, netpoll.WithOnDisconnect(func(ctx context.Context, connection netpoll.Connection) {
cancelFunc, ok := ctx.Value(ctxCancelKey).(context.CancelFunc)
if cancelFunc != nil && ok {
cancelFunc()
}
}))
}

// Create EventLoop
t.Lock()
t.eventLoop, err = netpoll.NewEventLoop(func(ctx context.Context, connection netpoll.Connection) error {
Expand Down
29 changes: 29 additions & 0 deletions pkg/network/netpoll/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,35 @@ func TestTransport(t *testing.T) {
assert.Assert(t, atomic.LoadInt32(&onDataFlag) == 1)
})

t.Run("TestSenseClientDisconnection", func(t *testing.T) {
var onReqFlag int32
li-jin-gou marked this conversation as resolved.
Show resolved Hide resolved
transporter := NewTransporter(&config.Options{
Addr: addr,
Network: nw,
SenseClientDisconnection: true,
})

go transporter.ListenAndServe(func(ctx context.Context, conn interface{}) error {
atomic.StoreInt32(&onReqFlag, 1)
time.Sleep(100 * time.Millisecond)
assert.DeepEqual(t, context.Canceled, ctx.Err())
return nil
})
defer transporter.Close()
time.Sleep(100 * time.Millisecond)

dial := NewDialer()
conn, err := dial.DialConnection(nw, addr, time.Second, nil)
assert.Nil(t, err)
_, err = conn.Write([]byte("123"))
assert.Nil(t, err)
err = conn.Close()
assert.Nil(t, err)
time.Sleep(100 * time.Millisecond)

assert.Assert(t, atomic.LoadInt32(&onReqFlag) == 1)
})

t.Run("TestListenConfig", func(t *testing.T) {
listenCfg := &net.ListenConfig{Control: func(network, address string, c syscall.RawConn) error {
return c.Control(func(fd uintptr) {
Expand Down
Loading