From 8bdaf3faaf298ba3c5514dfca3f76f0da2b08379 Mon Sep 17 00:00:00 2001 From: kinggo Date: Mon, 12 Feb 2024 16:48:53 +0800 Subject: [PATCH] feat: add cancelContext --- pkg/app/server/hertz_test.go | 43 ------------------------ pkg/app/server/hertz_unix_test.go | 54 +++++++++++++++++++++++++++++++ pkg/app/server/option.go | 32 +++++++++--------- pkg/network/netpoll/transport.go | 26 +++++++++++---- pkg/protocol/http1/proxy/proxy.go | 2 -- 5 files changed, 89 insertions(+), 68 deletions(-) diff --git a/pkg/app/server/hertz_test.go b/pkg/app/server/hertz_test.go index 9690c57e7..4a9cee3d0 100644 --- a/pkg/app/server/hertz_test.go +++ b/pkg/app/server/hertz_test.go @@ -43,7 +43,6 @@ import ( "github.com/cloudwego/hertz/pkg/common/test/mock" "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/network" - "github.com/cloudwego/hertz/pkg/network/netpoll" "github.com/cloudwego/hertz/pkg/network/standard" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/consts" @@ -1093,45 +1092,3 @@ func TestWithDisableDefaultContentType(t *testing.T) { r, _ := hc.Get("http://127.0.0.1:8324") //nolint:errcheck assert.DeepEqual(t, "", r.Header.Get("Content-Type")) } - -func TestWithSenseClientDisconnection(t *testing.T) { - h := New( - WithHostPorts("localhost:8327"), - WithSenseClientDisconnection(true), - ) - var wg sync.WaitGroup - wg.Add(1) - h.GET("/test", func(c context.Context, ctx *app.RequestContext) { - defer wg.Done() - select { - case <-c.Done(): - return - case <-time.After(time.Second): - t.Fatal("cancel context failed") - } - }) - go h.Spin() - time.Sleep(100 * time.Millisecond) - - dail := netpoll.NewDialer() - conn, err := dail.DialConnection("tcp", "127.0.0.1:8327", 0, nil) - assert.Nil(t, err) - tr := &http.Transport{ - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return conn, nil - }, - } - hc := http.Client{ - Timeout: time.Second, - Transport: tr, - } - - go func() { - _, err := hc.Get("http://127.0.0.1:8327/test") - assert.NotNil(t, err) - }() - time.Sleep(100 * time.Millisecond) - err = conn.Close() - assert.Nil(t, err) - wg.Wait() -} diff --git a/pkg/app/server/hertz_unix_test.go b/pkg/app/server/hertz_unix_test.go index b37ddfbdf..36653997e 100644 --- a/pkg/app/server/hertz_unix_test.go +++ b/pkg/app/server/hertz_unix_test.go @@ -30,6 +30,8 @@ import ( "testing" "time" + "github.com/cloudwego/hertz/pkg/network" + "github.com/cloudwego/hertz/pkg/app" c "github.com/cloudwego/hertz/pkg/app/client" "github.com/cloudwego/hertz/pkg/common/test/assert" @@ -134,3 +136,55 @@ func TestHertz_Spin(t *testing.T) { <-ch2 } + +func TestWithSenseClientDisconnection(t *testing.T) { + var closeFlag int32 + h := New(WithHostPorts("127.0.0.1:6631"), WithSenseClientDisconnection(true)) + h.GET("/ping", func(c context.Context, ctx *app.RequestContext) { + assert.DeepEqual(t, "aa", string(ctx.Host())) + ch := make(chan struct{}) + select { + case <-c.Done(): + atomic.StoreInt32(&closeFlag, 1) + assert.DeepEqual(t, context.Canceled, c.Err()) + case <-ch: + } + }) + go h.Spin() + time.Sleep(time.Second) + con, err := net.Dial("tcp", "127.0.0.1:6631") + assert.Nil(t, err) + _, err = con.Write([]byte("GET /ping HTTP/1.1\r\nHost: aa\r\n\r\n")) + assert.Nil(t, err) + time.Sleep(time.Second) + assert.Nil(t, con.Close()) + time.Sleep(time.Second) + atomic.StoreInt32(&closeFlag, 1) +} + +func TestWithSenseClientDisconnectionAndWithOnConnect(t *testing.T) { + var closeFlag int32 + h := New(WithHostPorts("127.0.0.1:6631"), WithSenseClientDisconnection(true), WithOnConnect(func(ctx context.Context, conn network.Conn) context.Context { + return ctx + })) + h.GET("/ping", func(c context.Context, ctx *app.RequestContext) { + assert.DeepEqual(t, "aa", string(ctx.Host())) + ch := make(chan struct{}) + select { + case <-c.Done(): + atomic.StoreInt32(&closeFlag, 1) + assert.DeepEqual(t, context.Canceled, c.Err()) + case <-ch: + } + }) + go h.Spin() + time.Sleep(time.Second) + con, err := net.Dial("tcp", "127.0.0.1:6631") + assert.Nil(t, err) + _, err = con.Write([]byte("GET /ping HTTP/1.1\r\nHost: aa\r\n\r\n")) + assert.Nil(t, err) + time.Sleep(time.Second) + assert.Nil(t, con.Close()) + time.Sleep(time.Second) + atomic.StoreInt32(&closeFlag, 1) +} diff --git a/pkg/app/server/option.go b/pkg/app/server/option.go index e1d93b9c2..18f184379 100644 --- a/pkg/app/server/option.go +++ b/pkg/app/server/option.go @@ -332,22 +332,6 @@ func WithDisablePrintRoute(b bool) config.Option { }} } -// WithSenseClientDisconnection sets the ability to sense client disconnections. -// If we don't set it, it will default to false. -// There are two issues to note when using this option: -// 1. Warning: It only applies to netpoll. -// 2. After opening, the context.Context in the request will be cancelled. -// -// Example: -// server.Default( -// server.WithSenseClientDisconnection(true), -// ) -func WithSenseClientDisconnection(b bool) config.Option { - return config.Option{F: func(o *config.Options) { - o.SenseClientDisconnection = b - }} -} - // WithOnAccept sets the callback function when a new connection is accepted but cannot // receive data in netpoll. In go net, it will be called before converting tls connection func WithOnAccept(fn func(conn net.Conn) context.Context) config.Option { @@ -410,3 +394,19 @@ func WithDisableDefaultContentType(disable bool) config.Option { o.NoDefaultContentType = disable }} } + +// WithSenseClientDisconnection sets the ability to sense client disconnections. +// If we don't set it, it will default to false. +// There are two issues to note when using this option: +// 1. Warning: It only applies to netpoll. +// 2. After opening, the context.Context in the request will be cancelled. +// +// Example: +// server.Default( +// server.WithSenseClientDisconnection(true), +// ) +func WithSenseClientDisconnection(b bool) config.Option { + return config.Option{F: func(o *config.Options) { + o.SenseClientDisconnection = b + }} +} diff --git a/pkg/network/netpoll/transport.go b/pkg/network/netpoll/transport.go index 6402bd146..aaa195965 100644 --- a/pkg/network/netpoll/transport.go +++ b/pkg/network/netpoll/transport.go @@ -36,6 +36,14 @@ func init() { netpoll.SetLoggerOutput(io.Discard) } +const ctxCancelKey = "ctxCancelKey" + +func cancelContext(ctx context.Context) context.Context { + ctx, cancel := context.WithCancel(ctx) + ctx = context.WithValue(ctx, ctxCancelKey, cancel) + return ctx +} + type transporter struct { sync.RWMutex senseClientDisconnection bool @@ -99,18 +107,22 @@ func (t *transporter) ListenAndServe(onReq network.OnData) (err error) { if t.OnConnect != nil { opts = append(opts, netpoll.WithOnConnect(func(ctx context.Context, conn netpoll.Connection) context.Context { + if t.senseClientDisconnection { + ctx = cancelContext(ctx) + } return t.OnConnect(ctx, newConn(conn)) })) } - const ctxKey = "ctxKey" if t.senseClientDisconnection { - opts = append(opts, netpoll.WithOnConnect(func(ctx context.Context, connection netpoll.Connection) context.Context { - ctx, cancel := context.WithCancel(ctx) - return context.WithValue(ctx, ctxKey, cancel) - }), netpoll.WithOnDisconnect(func(ctx context.Context, connection netpoll.Connection) { - cancelFunc, _ := ctx.Value(ctxKey).(context.CancelFunc) - if cancelFunc != nil { + if t.OnConnect == nil { + opts = append(opts, netpoll.WithOnConnect(func(ctx context.Context, connection netpoll.Connection) context.Context { + return cancelContext(ctx) + })) + } + opts = append(opts, netpoll.WithOnDisconnect(func(ctx context.Context, connection netpoll.Connection) { + cancelFunc, ok := ctx.Value(ctxCancelKey).(context.CancelFunc) + if cancelFunc != nil && ok { cancelFunc() } })) diff --git a/pkg/protocol/http1/proxy/proxy.go b/pkg/protocol/http1/proxy/proxy.go index f8bae7608..2b243ff04 100644 --- a/pkg/protocol/http1/proxy/proxy.go +++ b/pkg/protocol/http1/proxy/proxy.go @@ -81,13 +81,11 @@ func SetupProxy(conn network.Conn, addr string, proxyURI *protocol.URI, tlsConfi defer close(didReadResponse) err = reqI.Write(connectReq, conn) - if err != nil { return } err = conn.Flush() - if err != nil { return }