Skip to content

Commit

Permalink
feat: response support hijack for upgrade conn (#1214)
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaost authored Oct 21, 2024
1 parent 0e582ec commit 3477b03
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 6 deletions.
58 changes: 52 additions & 6 deletions pkg/protocol/http1/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ import (
"errors"
"io"
"net"
"runtime"
"strings"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -689,31 +690,76 @@ func (c *HostClient) doNonNilReqResp(req *protocol.Request, resp *protocol.Respo
return nil
})
}
zr.Release() //nolint:errcheck

if err != nil {
zr.Release() //nolint:errcheck
c.closeConn(cc)
// Don't retry in case of ErrBodyTooLarge since we will just get the same again.
retry := !errors.Is(err, errs.ErrBodyTooLarge)
return retry, err
}

zr.Release() //nolint:errcheck

shouldCloseConn = resetConnection || req.ConnectionClose() || resp.ConnectionClose()

if resp.Header.StatusCode() == consts.StatusSwitchingProtocols &&
bytes.Equal(resp.Header.Peek(consts.HeaderConnection), bytestr.StrUpgrade) {
// can not reuse connection in this case, it's no longer http1 protocol.
// set BodyStream for (*Response).Hijack
resp.SetBodyStream(newUpgradeConn(c, cc), -1)
return false, nil
}

// In stream mode, we still can close/release the connection immediately if there is no content on the wire.
if c.ResponseBodyStream && resp.BodyStream() != protocol.NoResponseBody {
return false, err
return false, nil
}

if shouldCloseConn {
c.closeConn(cc)
} else {
c.releaseConn(cc)
}
return false, nil
}

var poolUpgradeConn = sync.Pool{
New: func() interface{} {
return &upgradeConn{}
},
}

type upgradeConn struct {
c *HostClient
cc *clientConn
}

func newUpgradeConn(c *HostClient, cc *clientConn) *upgradeConn {
p := poolUpgradeConn.Get().(*upgradeConn)
p.c = c
p.cc = cc
runtime.SetFinalizer(p, (*upgradeConn).gc)
return p
}

return false, err
// Read implements io.Reader
func (p *upgradeConn) Read(b []byte) (int, error) { return p.cc.c.Read(b) }

// Hijack returns underlying network.Conn. This method is called by (*Response).Hijack
func (p *upgradeConn) Hijack() (network.Conn, error) { return p.cc.c, nil }

// gc closes conn and reuse upgradeConn.
//
// It MUST be called only by go runtime to avoid concurenccy issue.
// For the 1st GC, it closes conn, and put upgradeConn back to pool
// For the 2nd GC, it will be recycled if it's still in pool
func (p *upgradeConn) gc() error {
if p.c != nil {
runtime.SetFinalizer(p, nil)
p.c.closeConn(p.cc)
p.c = nil
p.cc = nil
poolUpgradeConn.Put(p)
}
return nil
}

func (c *HostClient) Close() error {
Expand Down
59 changes: 59 additions & 0 deletions pkg/protocol/http1/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ import (
"fmt"
"io/ioutil"
"net"
"net/http"
"strings"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -311,6 +312,64 @@ func TestDoNonNilReqResp1(t *testing.T) {
assert.NotNil(t, err)
}

func TestConnUpgrade(t *testing.T) {
ln, _ := net.Listen("tcp", "localhost:0")
defer ln.Close()
svr := http.Server{}
svr.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
hj, ok := w.(http.Hijacker)
if !ok {
http.Error(w, "webserver doesn't support hijacking", http.StatusInternalServerError)
return
}
conn, rw, err := hj.Hijack()
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
defer conn.Close()
_, err = rw.WriteString("HTTP/1.1 101 Switching Protocols\nConnection: Upgrade\n\n")
assert.Nil(t, err)
assert.Nil(t, rw.Flush())
b := make([]byte, 100)
for { // echo with "echo:" prefix
n, err := rw.Read(b)
if err != nil {
return
}
_, err = rw.Write([]byte("echo:" + string(b[:n])))
if err != nil {
return
}
_ = rw.Flush()
}
})
go svr.Serve(ln)

c := &HostClient{
Addr: ln.Addr().String(),
ClientOptions: &ClientOptions{},
}
req := protocol.AcquireRequest()
req.SetRequestURI("http://" + ln.Addr().String() + "/")
resp := protocol.AcquireResponse()
retry, err := c.doNonNilReqResp(req, resp)
assert.False(t, retry)
assert.Nil(t, err)
assert.DeepEqual(t, resp.StatusCode(), 101)

s := resp.BodyStream()
assert.NotNil(t, s)
conn, err := resp.Hijack()
assert.Nil(t, err)

b := make([]byte, 100)
_, _ = conn.Write(append(b[:0], "hello"...))
n, err := s.Read(b) // same as conn.Read
assert.Nil(t, err)
assert.DeepEqual(t, string(b[:n]), "echo:hello")
}

func TestWriteTimeoutPriority(t *testing.T) {
c := &HostClient{
ClientOptions: &ClientOptions{
Expand Down
20 changes: 20 additions & 0 deletions pkg/protocol/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
package protocol

import (
"errors"
"io"
"net"
"sync"
Expand Down Expand Up @@ -346,6 +347,25 @@ func (resp *Response) BodyStream() io.Reader {
return resp.bodyStream
}

// Hijack returns the underlying network.Conn if available.
//
// It's only available when StatusCode() == 101 and "Connection: Upgrade",
// coz Hertz will NOT reuse connection in this case,
// then make it optional for users to implement their own protocols.
//
// The most common scenario is used with github.com/hertz-contrib/websocket
func (resp *Response) Hijack() (network.Conn, error) {
if resp.bodyStream != nil {
h, ok := resp.bodyStream.(interface {
Hijack() (network.Conn, error)
})
if ok {
return h.Hijack()
}
}
return nil, errors.New("not available")
}

// AppendBody appends p to response body.
//
// It is safe re-using p after the function returns.
Expand Down
19 changes: 19 additions & 0 deletions pkg/protocol/response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ package protocol

import (
"bytes"
"errors"
"fmt"
"math"
"reflect"
Expand All @@ -52,6 +53,7 @@ import (
"github.com/cloudwego/hertz/pkg/common/compress"
"github.com/cloudwego/hertz/pkg/common/test/assert"
"github.com/cloudwego/hertz/pkg/common/test/mock"
"github.com/cloudwego/hertz/pkg/network"
"github.com/cloudwego/hertz/pkg/protocol/consts"
)

Expand Down Expand Up @@ -294,3 +296,20 @@ func TestResponse_HijackWriter(t *testing.T) {
resp.GetHijackWriter().Finalize()
assert.True(t, isFinal)
}

type HijackerFunc func() (network.Conn, error)

func (h HijackerFunc) Read(_ []byte) (int, error) { return 0, errors.New("not implemented") }
func (h HijackerFunc) Hijack() (network.Conn, error) { return h() }

func TestResponse_Hijack(t *testing.T) {
resp := AcquireResponse()
defer ReleaseResponse(resp)

_, err := resp.Hijack()
assert.NotNil(t, err)

resp.SetBodyStream(HijackerFunc(func() (network.Conn, error) { return nil, nil }), -1)
_, err = resp.Hijack()
assert.Nil(t, err)
}

0 comments on commit 3477b03

Please sign in to comment.