diff --git a/pkg/protocol/http1/client.go b/pkg/protocol/http1/client.go index f7b04ba25..149029db0 100644 --- a/pkg/protocol/http1/client.go +++ b/pkg/protocol/http1/client.go @@ -48,6 +48,7 @@ import ( "errors" "io" "net" + "runtime" "strings" "sync" "sync/atomic" @@ -689,22 +690,27 @@ func (c *HostClient) doNonNilReqResp(req *protocol.Request, resp *protocol.Respo return nil }) } + zr.Release() //nolint:errcheck if err != nil { - zr.Release() //nolint:errcheck c.closeConn(cc) // Don't retry in case of ErrBodyTooLarge since we will just get the same again. retry := !errors.Is(err, errs.ErrBodyTooLarge) return retry, err } - - zr.Release() //nolint:errcheck - shouldCloseConn = resetConnection || req.ConnectionClose() || resp.ConnectionClose() + if resp.Header.StatusCode() == consts.StatusSwitchingProtocols && + bytes.Equal(resp.Header.Peek(consts.HeaderConnection), bytestr.StrUpgrade) { + // can not reuse connection in this case, it's no longer http1 protocol. + // set BodyStream for (*Response).Hijack + resp.SetBodyStream(newUpgradeConn(c, cc), -1) + return false, nil + } + // In stream mode, we still can close/release the connection immediately if there is no content on the wire. if c.ResponseBodyStream && resp.BodyStream() != protocol.NoResponseBody { - return false, err + return false, nil } if shouldCloseConn { @@ -712,8 +718,48 @@ func (c *HostClient) doNonNilReqResp(req *protocol.Request, resp *protocol.Respo } else { c.releaseConn(cc) } + return false, nil +} + +var poolUpgradeConn = sync.Pool{ + New: func() interface{} { + return &upgradeConn{} + }, +} + +type upgradeConn struct { + c *HostClient + cc *clientConn +} + +func newUpgradeConn(c *HostClient, cc *clientConn) *upgradeConn { + p := poolUpgradeConn.Get().(*upgradeConn) + p.c = c + p.cc = cc + runtime.SetFinalizer(p, (*upgradeConn).gc) + return p +} - return false, err +// Read implements io.Reader +func (p *upgradeConn) Read(b []byte) (int, error) { return p.cc.c.Read(b) } + +// Hijack returns underlying network.Conn. This method is called by (*Response).Hijack +func (p *upgradeConn) Hijack() (network.Conn, error) { return p.cc.c, nil } + +// gc closes conn and reuse upgradeConn. +// +// It MUST be called only by go runtime to avoid concurenccy issue. +// For the 1st GC, it closes conn, and put upgradeConn back to pool +// For the 2nd GC, it will be recycled if it's still in pool +func (p *upgradeConn) gc() error { + if p.c != nil { + runtime.SetFinalizer(p, nil) + p.c.closeConn(p.cc) + p.c = nil + p.cc = nil + poolUpgradeConn.Put(p) + } + return nil } func (c *HostClient) Close() error { diff --git a/pkg/protocol/http1/client_test.go b/pkg/protocol/http1/client_test.go index 8c0869dde..2ecb764a6 100644 --- a/pkg/protocol/http1/client_test.go +++ b/pkg/protocol/http1/client_test.go @@ -49,6 +49,7 @@ import ( "fmt" "io/ioutil" "net" + "net/http" "strings" "sync" "sync/atomic" @@ -311,6 +312,64 @@ func TestDoNonNilReqResp1(t *testing.T) { assert.NotNil(t, err) } +func TestConnUpgrade(t *testing.T) { + ln, _ := net.Listen("tcp", "localhost:0") + defer ln.Close() + svr := http.Server{} + svr.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hj, ok := w.(http.Hijacker) + if !ok { + http.Error(w, "webserver doesn't support hijacking", http.StatusInternalServerError) + return + } + conn, rw, err := hj.Hijack() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + defer conn.Close() + _, err = rw.WriteString("HTTP/1.1 101 Switching Protocols\nConnection: Upgrade\n\n") + assert.Nil(t, err) + assert.Nil(t, rw.Flush()) + b := make([]byte, 100) + for { // echo with "echo:" prefix + n, err := rw.Read(b) + if err != nil { + return + } + _, err = rw.Write([]byte("echo:" + string(b[:n]))) + if err != nil { + return + } + _ = rw.Flush() + } + }) + go svr.Serve(ln) + + c := &HostClient{ + Addr: ln.Addr().String(), + ClientOptions: &ClientOptions{}, + } + req := protocol.AcquireRequest() + req.SetRequestURI("http://" + ln.Addr().String() + "/") + resp := protocol.AcquireResponse() + retry, err := c.doNonNilReqResp(req, resp) + assert.False(t, retry) + assert.Nil(t, err) + assert.DeepEqual(t, resp.StatusCode(), 101) + + s := resp.BodyStream() + assert.NotNil(t, s) + conn, err := resp.Hijack() + assert.Nil(t, err) + + b := make([]byte, 100) + _, _ = conn.Write(append(b[:0], "hello"...)) + n, err := s.Read(b) // same as conn.Read + assert.Nil(t, err) + assert.DeepEqual(t, string(b[:n]), "echo:hello") +} + func TestWriteTimeoutPriority(t *testing.T) { c := &HostClient{ ClientOptions: &ClientOptions{ diff --git a/pkg/protocol/response.go b/pkg/protocol/response.go index 8beb38597..ccace1399 100644 --- a/pkg/protocol/response.go +++ b/pkg/protocol/response.go @@ -42,6 +42,7 @@ package protocol import ( + "errors" "io" "net" "sync" @@ -346,6 +347,25 @@ func (resp *Response) BodyStream() io.Reader { return resp.bodyStream } +// Hijack returns the underlying network.Conn if available. +// +// It's only available when StatusCode() == 101 and "Connection: Upgrade", +// coz Hertz will NOT reuse connection in this case, +// then make it optional for users to implement their own protocols. +// +// The most common scenario is used with github.com/hertz-contrib/websocket +func (resp *Response) Hijack() (network.Conn, error) { + if resp.bodyStream != nil { + h, ok := resp.bodyStream.(interface { + Hijack() (network.Conn, error) + }) + if ok { + return h.Hijack() + } + } + return nil, errors.New("not available") +} + // AppendBody appends p to response body. // // It is safe re-using p after the function returns. diff --git a/pkg/protocol/response_test.go b/pkg/protocol/response_test.go index 20a18ffce..6b09c2e19 100644 --- a/pkg/protocol/response_test.go +++ b/pkg/protocol/response_test.go @@ -43,6 +43,7 @@ package protocol import ( "bytes" + "errors" "fmt" "math" "reflect" @@ -52,6 +53,7 @@ import ( "github.com/cloudwego/hertz/pkg/common/compress" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/common/test/mock" + "github.com/cloudwego/hertz/pkg/network" "github.com/cloudwego/hertz/pkg/protocol/consts" ) @@ -294,3 +296,20 @@ func TestResponse_HijackWriter(t *testing.T) { resp.GetHijackWriter().Finalize() assert.True(t, isFinal) } + +type HijackerFunc func() (network.Conn, error) + +func (h HijackerFunc) Read(_ []byte) (int, error) { return 0, errors.New("not implemented") } +func (h HijackerFunc) Hijack() (network.Conn, error) { return h() } + +func TestResponse_Hijack(t *testing.T) { + resp := AcquireResponse() + defer ReleaseResponse(resp) + + _, err := resp.Hijack() + assert.NotNil(t, err) + + resp.SetBodyStream(HijackerFunc(func() (network.Conn, error) { return nil, nil }), -1) + _, err = resp.Hijack() + assert.Nil(t, err) +}