Skip to content

Commit

Permalink
feat(http1): add env to disable request context pool (#1173)
Browse files Browse the repository at this point in the history
  • Loading branch information
welkeyever authored Aug 22, 2024
1 parent a64f390 commit c29f150
Show file tree
Hide file tree
Showing 8 changed files with 330 additions and 72 deletions.
234 changes: 185 additions & 49 deletions pkg/app/client/client_test.go

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pkg/app/server/binding/internal/decoder/text_decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func SelectTextDecoder(rt reflect.Type) (TextDecoder, error) {
return &interfaceDecoder{}, nil
}

return nil, fmt.Errorf("unsupported type " + rt.String())
return nil, fmt.Errorf("unsupported type %s", rt.String())
}

type boolDecoder struct{}
Expand Down
37 changes: 29 additions & 8 deletions pkg/app/server/hertz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,10 @@ func TestLoadHTMLGlob(t *testing.T) {
})
})
go engine.Run()
time.Sleep(200 * time.Millisecond)
defer func() {
engine.Close()
}()
time.Sleep(1 * time.Second)
resp, _ := http.Get("http://127.0.0.1:8893/index")
assert.DeepEqual(t, consts.StatusOK, resp.StatusCode)
b := make([]byte, 100)
Expand All @@ -182,7 +185,10 @@ func TestLoadHTMLFiles(t *testing.T) {
})
})
go engine.Run()
time.Sleep(200 * time.Millisecond)
defer func() {
engine.Close()
}()
time.Sleep(1 * time.Second)
resp, _ := http.Get("http://127.0.0.1:8891/raw")
assert.DeepEqual(t, consts.StatusOK, resp.StatusCode)
b := make([]byte, 100)
Expand Down Expand Up @@ -227,7 +233,7 @@ func TestServer_Run(t *testing.T) {
ctx.Redirect(consts.StatusMovedPermanently, []byte("http://127.0.0.1:8899/test"))
})
go hertz.Run()
time.Sleep(100 * time.Microsecond)
time.Sleep(1 * time.Second)
resp, err := http.Get("http://127.0.0.1:8899/test")
assert.Nil(t, err)
assert.DeepEqual(t, consts.StatusOK, resp.StatusCode)
Expand Down Expand Up @@ -260,7 +266,10 @@ func TestNotAbsolutePath(t *testing.T) {
ctx.Write(ctx.Request.Body())
})
go engine.Run()
time.Sleep(200 * time.Microsecond)
defer func() {
engine.Close()
}()
time.Sleep(1 * time.Second)

s := "POST ?a=b HTTP/1.1\r\nHost: a.b.c\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343"
zr := mock.NewZeroCopyReader(s)
Expand Down Expand Up @@ -299,7 +308,10 @@ func TestNotAbsolutePathWithRawPath(t *testing.T) {
engine.POST("/a", func(c context.Context, ctx *app.RequestContext) {
})
go engine.Run()
time.Sleep(200 * time.Microsecond)
defer func() {
engine.Close()
}()
time.Sleep(1 * time.Second)

s := "POST ?a=b HTTP/1.1\r\nHost: a.b.c\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343"
zr := mock.NewZeroCopyReader(s)
Expand Down Expand Up @@ -374,7 +386,10 @@ func TestWithBasePath(t *testing.T) {
engine.POST("/test", func(c context.Context, ctx *app.RequestContext) {
})
go engine.Run()
time.Sleep(500 * time.Microsecond)
defer func() {
engine.Close()
}()
time.Sleep(1 * time.Second)
var r http.Request
r.ParseForm()
r.Form.Add("xxxxxx", "xxx")
Expand All @@ -389,7 +404,10 @@ func TestNotEnoughBodySize(t *testing.T) {
engine.POST("/test", func(c context.Context, ctx *app.RequestContext) {
})
go engine.Run()
time.Sleep(200 * time.Microsecond)
defer func() {
engine.Close()
}()
time.Sleep(1 * time.Second)
var r http.Request
r.ParseForm()
r.Form.Add("xxxxxx", "xxx")
Expand All @@ -406,7 +424,10 @@ func TestEnoughBodySize(t *testing.T) {
engine.POST("/test", func(c context.Context, ctx *app.RequestContext) {
})
go engine.Run()
time.Sleep(200 * time.Microsecond)
defer func() {
engine.Close()
}()
time.Sleep(1 * time.Second)
var r http.Request
r.ParseForm()
r.Form.Add("xxxxxx", "xxx")
Expand Down
36 changes: 36 additions & 0 deletions pkg/common/utils/env.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Copyright 2024 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package utils

import (
"os"
"strconv"
"strings"

"github.com/cloudwego/hertz/pkg/common/errors"
)

// Get bool from env
func GetBoolFromEnv(key string) (bool, error) {
value, isExist := os.LookupEnv(key)
if !isExist {
return false, errors.NewPublic("env not exist")
}

value = strings.TrimSpace(value)
return strconv.ParseBool(value)
}
31 changes: 27 additions & 4 deletions pkg/protocol/http1/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
errs "github.com/cloudwego/hertz/pkg/common/errors"
"github.com/cloudwego/hertz/pkg/common/tracer/stats"
"github.com/cloudwego/hertz/pkg/common/tracer/traceinfo"
"github.com/cloudwego/hertz/pkg/common/utils"
"github.com/cloudwego/hertz/pkg/network"
"github.com/cloudwego/hertz/pkg/protocol"
"github.com/cloudwego/hertz/pkg/protocol/consts"
Expand All @@ -41,6 +42,12 @@ import (
"github.com/cloudwego/hertz/pkg/protocol/suite"
)

func init() {
if b, err := utils.GetBoolFromEnv("HERTZ_DISABLE_REQUEST_CONTEXT_POOL"); err == nil {
disabaleRequestContextPool = b
}
}

// NextProtoTLS is the NPN/ALPN protocol negotiated during
// HTTP/1.1's TLS setup.
// Also used for server addressing
Expand All @@ -51,6 +58,8 @@ var (
errIdleTimeout = errs.New(errs.ErrIdleTimeout, errs.ErrorTypePrivate, nil)
errShortConnection = errs.New(errs.ErrShortConnection, errs.ErrorTypePublic, "server is going to close the connection")
errUnexpectedEOF = errs.NewPublic(io.ErrUnexpectedEOF.Error() + " when reading request")

disabaleRequestContextPool = false
)

type Option struct {
Expand Down Expand Up @@ -80,6 +89,21 @@ type Server struct {
eventStackPool *sync.Pool
}

func (s Server) getRequestContext() *app.RequestContext {
if disabaleRequestContextPool {
return &app.RequestContext{}
}
return s.Core.GetCtxPool().Get().(*app.RequestContext)
}

func (s Server) putRequestContext(ctx *app.RequestContext) {
if disabaleRequestContextPool {
return
}
ctx.Reset()
s.Core.GetCtxPool().Put(ctx)
}

func (s Server) Serve(c context.Context, conn network.Conn) (err error) {
var (
zr network.Reader
Expand All @@ -97,8 +121,8 @@ func (s Server) Serve(c context.Context, conn network.Conn) (err error) {
// 1. Get a request context
// 2. Prepare it
// 3. Process it
// 4. Reset and recycle
ctx = s.Core.GetCtxPool().Get().(*app.RequestContext)
// 4. Reset and recycle(in pooled mode)
ctx = s.getRequestContext()

traceCtl = s.Core.GetTracer()
eventsToTrigger *eventStack
Expand Down Expand Up @@ -138,8 +162,7 @@ func (s Server) Serve(c context.Context, conn network.Conn) (err error) {
return
}

ctx.Reset()
s.Core.GetCtxPool().Put(ctx)
s.putRequestContext(ctx)
}()

ctx.HTMLRender = s.HTMLRender
Expand Down
40 changes: 40 additions & 0 deletions pkg/protocol/http1/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,46 @@ func TestDefaultWriter(t *testing.T) {
assert.DeepEqual(t, "hello, hertz", string(response.Body()))
}

func TestServerDisableReqCtxPool(t *testing.T) {
server := &Server{}
reqCtx := &app.RequestContext{}
server.Core = &mockCore{
ctxPool: &sync.Pool{New: func() interface{} {
reqCtx.Set("POOL_KEY", "in pool")
return reqCtx
}},
mockHandler: func(c context.Context, ctx *app.RequestContext) {
if ctx.GetString("POOL_KEY") != "in pool" {
t.Fatal("reqCtx is not in pool")
}
},
isRunning: true,
}
defaultConn := mock.NewConn("GET / HTTP/1.1\nHost: foobar.com\n\n")
err := server.Serve(context.TODO(), defaultConn)
assert.Nil(t, err)
disabaleRequestContextPool = true
defer func() {
// reset global variable
disabaleRequestContextPool = false
}()
server.Core = &mockCore{
ctxPool: &sync.Pool{New: func() interface{} {
reqCtx.Set("POOL_KEY", "in pool")
return reqCtx
}},
mockHandler: func(c context.Context, ctx *app.RequestContext) {
if len(ctx.GetString("POOL_KEY")) != 0 {
t.Fatal("must not get pool key")
}
},
isRunning: true,
}
defaultConn = mock.NewConn("GET / HTTP/1.1\nHost: foobar.com\n\n")
err = server.Serve(context.TODO(), defaultConn)
assert.Nil(t, err)
}

func TestHijackResponseWriter(t *testing.T) {
server := &Server{}
reqCtx := &app.RequestContext{}
Expand Down
10 changes: 5 additions & 5 deletions pkg/route/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,11 +349,6 @@ func (engine *Engine) Run() (err error) {
return err
}

if err = engine.MarkAsRunning(); err != nil {
return err
}
defer atomic.StoreUint32(&engine.status, statusClosed)

// trigger hooks if any
ctx := context.Background()
for i := range engine.OnRun {
Expand All @@ -362,6 +357,11 @@ func (engine *Engine) Run() (err error) {
}
}

if err = engine.MarkAsRunning(); err != nil {
return err
}
defer atomic.StoreUint32(&engine.status, statusClosed)

return engine.listenAndServe()
}

Expand Down
12 changes: 7 additions & 5 deletions pkg/route/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -875,20 +875,22 @@ func TestEngineShutdown(t *testing.T) {
defaultTransporter = standard.NewTransporter
mockCtxCallback := func(ctx context.Context) {}
// Test case 1: serve not running error
engine := NewEngine(config.NewOptions(nil))
opt := config.NewOptions(nil)
opt.Addr = "127.0.0.1:10027"
engine := NewEngine(opt)
ctx1, cancel1 := context.WithTimeout(context.Background(), time.Second)
defer cancel1()
err := engine.Shutdown(ctx1)
assert.DeepEqual(t, errStatusNotRunning, err)

// Test case 2: serve successfully running and shutdown
engine = NewEngine(config.NewOptions(nil))
engine = NewEngine(opt)
engine.OnShutdown = []CtxCallback{mockCtxCallback}
go func() {
engine.Run()
}()
// wait for engine to start
time.Sleep(100 * time.Millisecond)
time.Sleep(1 * time.Second)

ctx2, cancel2 := context.WithTimeout(context.Background(), time.Second)
defer cancel2()
Expand All @@ -897,14 +899,14 @@ func TestEngineShutdown(t *testing.T) {
assert.DeepEqual(t, statusClosed, atomic.LoadUint32(&engine.status))

// Test case 3: serve successfully running and shutdown with deregistry error
engine = NewEngine(config.NewOptions(nil))
engine = NewEngine(opt)
engine.OnShutdown = []CtxCallback{mockCtxCallback}
engine.options.Registry = &mockDeregsitryErr{}
go func() {
engine.Run()
}()
// wait for engine to start
time.Sleep(100 * time.Millisecond)
time.Sleep(1 * time.Second)

ctx3, cancel3 := context.WithTimeout(context.Background(), time.Second)
defer cancel3()
Expand Down

0 comments on commit c29f150

Please sign in to comment.