Skip to content

Commit

Permalink
feat: add cancelContext
Browse files Browse the repository at this point in the history
  • Loading branch information
li-jin-gou committed Feb 12, 2024
1 parent c63f652 commit dfa79d0
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 70 deletions.
43 changes: 0 additions & 43 deletions pkg/app/server/hertz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ import (
"github.com/cloudwego/hertz/pkg/common/test/mock"
"github.com/cloudwego/hertz/pkg/common/utils"
"github.com/cloudwego/hertz/pkg/network"
"github.com/cloudwego/hertz/pkg/network/netpoll"
"github.com/cloudwego/hertz/pkg/network/standard"
"github.com/cloudwego/hertz/pkg/protocol"
"github.com/cloudwego/hertz/pkg/protocol/consts"
Expand Down Expand Up @@ -1093,45 +1092,3 @@ func TestWithDisableDefaultContentType(t *testing.T) {
r, _ := hc.Get("http://127.0.0.1:8324") //nolint:errcheck
assert.DeepEqual(t, "", r.Header.Get("Content-Type"))
}

func TestWithSenseClientDisconnection(t *testing.T) {
h := New(
WithHostPorts("localhost:8327"),
WithSenseClientDisconnection(true),
)
var wg sync.WaitGroup
wg.Add(1)
h.GET("/test", func(c context.Context, ctx *app.RequestContext) {
defer wg.Done()
select {
case <-c.Done():
return
case <-time.After(time.Second):
t.Fatal("cancel context failed")
}
})
go h.Spin()
time.Sleep(100 * time.Millisecond)

dail := netpoll.NewDialer()
conn, err := dail.DialConnection("tcp", "127.0.0.1:8327", 0, nil)
assert.Nil(t, err)
tr := &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return conn, nil
},
}
hc := http.Client{
Timeout: time.Second,
Transport: tr,
}

go func() {
_, err := hc.Get("http://127.0.0.1:8327/test")
assert.NotNil(t, err)
}()
time.Sleep(100 * time.Millisecond)
err = conn.Close()
assert.Nil(t, err)
wg.Wait()
}
53 changes: 53 additions & 0 deletions pkg/app/server/hertz_unix_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
c "github.com/cloudwego/hertz/pkg/app/client"
"github.com/cloudwego/hertz/pkg/common/test/assert"
"github.com/cloudwego/hertz/pkg/common/utils"
"github.com/cloudwego/hertz/pkg/network"
"github.com/cloudwego/hertz/pkg/network/standard"
"github.com/cloudwego/hertz/pkg/protocol/consts"
"golang.org/x/sys/unix"
Expand Down Expand Up @@ -134,3 +135,55 @@ func TestHertz_Spin(t *testing.T) {

<-ch2
}

func TestWithSenseClientDisconnection(t *testing.T) {
var closeFlag int32
h := New(WithHostPorts("127.0.0.1:6631"), WithSenseClientDisconnection(true))
h.GET("/ping", func(c context.Context, ctx *app.RequestContext) {
assert.DeepEqual(t, "aa", string(ctx.Host()))
ch := make(chan struct{})
select {
case <-c.Done():
atomic.StoreInt32(&closeFlag, 1)
assert.DeepEqual(t, context.Canceled, c.Err())
case <-ch:
}
})
go h.Spin()
time.Sleep(time.Second)
con, err := net.Dial("tcp", "127.0.0.1:6631")
assert.Nil(t, err)
_, err = con.Write([]byte("GET /ping HTTP/1.1\r\nHost: aa\r\n\r\n"))
assert.Nil(t, err)
time.Sleep(time.Second)
assert.Nil(t, con.Close())
time.Sleep(time.Second)
assert.DeepEqual(t, atomic.LoadInt32(&closeFlag), int32(1))
}

func TestWithSenseClientDisconnectionAndWithOnConnect(t *testing.T) {
var closeFlag int32
h := New(WithHostPorts("127.0.0.1:6632"), WithSenseClientDisconnection(true), WithOnConnect(func(ctx context.Context, conn network.Conn) context.Context {
return ctx
}))
h.GET("/ping", func(c context.Context, ctx *app.RequestContext) {
assert.DeepEqual(t, "aa", string(ctx.Host()))
ch := make(chan struct{})
select {
case <-c.Done():
atomic.StoreInt32(&closeFlag, 1)
assert.DeepEqual(t, context.Canceled, c.Err())
case <-ch:
}
})
go h.Spin()
time.Sleep(time.Second)
con, err := net.Dial("tcp", "127.0.0.1:6632")
assert.Nil(t, err)
_, err = con.Write([]byte("GET /ping HTTP/1.1\r\nHost: aa\r\n\r\n"))
assert.Nil(t, err)
time.Sleep(time.Second)
assert.Nil(t, con.Close())
time.Sleep(time.Second)
assert.DeepEqual(t, atomic.LoadInt32(&closeFlag), int32(1))
}
32 changes: 16 additions & 16 deletions pkg/app/server/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,22 +332,6 @@ func WithDisablePrintRoute(b bool) config.Option {
}}
}

// WithSenseClientDisconnection sets the ability to sense client disconnections.
// If we don't set it, it will default to false.
// There are two issues to note when using this option:
// 1. Warning: It only applies to netpoll.
// 2. After opening, the context.Context in the request will be cancelled.
//
// Example:
// server.Default(
// server.WithSenseClientDisconnection(true),
// )
func WithSenseClientDisconnection(b bool) config.Option {
return config.Option{F: func(o *config.Options) {
o.SenseClientDisconnection = b
}}
}

// WithOnAccept sets the callback function when a new connection is accepted but cannot
// receive data in netpoll. In go net, it will be called before converting tls connection
func WithOnAccept(fn func(conn net.Conn) context.Context) config.Option {
Expand Down Expand Up @@ -410,3 +394,19 @@ func WithDisableDefaultContentType(disable bool) config.Option {
o.NoDefaultContentType = disable
}}
}

// WithSenseClientDisconnection sets the ability to sense client disconnections.
// If we don't set it, it will default to false.
// There are two issues to note when using this option:
// 1. Warning: It only applies to netpoll.
// 2. After opening, the context.Context in the request will be cancelled.
//
// Example:
// server.Default(
// server.WithSenseClientDisconnection(true),
// )
func WithSenseClientDisconnection(b bool) config.Option {
return config.Option{F: func(o *config.Options) {
o.SenseClientDisconnection = b
}}
}
26 changes: 17 additions & 9 deletions pkg/network/netpoll/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ func init() {
netpoll.SetLoggerOutput(io.Discard)
}

const ctxCancelKey = "ctxCancelKey"

func cancelContext(ctx context.Context) context.Context {
ctx, cancel := context.WithCancel(ctx)
ctx = context.WithValue(ctx, ctxCancelKey, cancel)
return ctx
}

type transporter struct {
sync.RWMutex
senseClientDisconnection bool
Expand Down Expand Up @@ -90,10 +98,14 @@ func (t *transporter) ListenAndServe(onReq network.OnData) (err error) {
if t.writeTimeout > 0 {
conn.SetWriteTimeout(t.writeTimeout)
}
ctx := context.Background()
if t.OnAccept != nil {
return t.OnAccept(newConn(conn))
ctx = t.OnAccept(newConn(conn))
}
if t.senseClientDisconnection {
ctx = cancelContext(ctx)
}
return context.Background()
return ctx
}),
}

Expand All @@ -103,14 +115,10 @@ func (t *transporter) ListenAndServe(onReq network.OnData) (err error) {
}))
}

const ctxKey = "ctxKey"
if t.senseClientDisconnection {
opts = append(opts, netpoll.WithOnConnect(func(ctx context.Context, connection netpoll.Connection) context.Context {
ctx, cancel := context.WithCancel(ctx)
return context.WithValue(ctx, ctxKey, cancel)
}), netpoll.WithOnDisconnect(func(ctx context.Context, connection netpoll.Connection) {
cancelFunc, _ := ctx.Value(ctxKey).(context.CancelFunc)
if cancelFunc != nil {
opts = append(opts, netpoll.WithOnDisconnect(func(ctx context.Context, connection netpoll.Connection) {
cancelFunc, ok := ctx.Value(ctxCancelKey).(context.CancelFunc)
if cancelFunc != nil && ok {
cancelFunc()
}
}))
Expand Down
2 changes: 0 additions & 2 deletions pkg/protocol/http1/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,11 @@ func SetupProxy(conn network.Conn, addr string, proxyURI *protocol.URI, tlsConfi
defer close(didReadResponse)

err = reqI.Write(connectReq, conn)

if err != nil {
return
}

err = conn.Flush()

if err != nil {
return
}
Expand Down

0 comments on commit dfa79d0

Please sign in to comment.