Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(http1): add env to disable request context pool #1173

Merged
merged 12 commits into from
Aug 22, 2024
234 changes: 185 additions & 49 deletions pkg/app/client/client_test.go

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pkg/app/server/binding/internal/decoder/text_decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func SelectTextDecoder(rt reflect.Type) (TextDecoder, error) {
return &interfaceDecoder{}, nil
}

return nil, fmt.Errorf("unsupported type " + rt.String())
return nil, fmt.Errorf("unsupported type %s", rt.String())
}

type boolDecoder struct{}
Expand Down
37 changes: 29 additions & 8 deletions pkg/app/server/hertz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,10 @@ func TestLoadHTMLGlob(t *testing.T) {
})
})
go engine.Run()
time.Sleep(200 * time.Millisecond)
defer func() {
engine.Close()
}()
time.Sleep(1 * time.Second)
resp, _ := http.Get("http://127.0.0.1:8893/index")
assert.DeepEqual(t, consts.StatusOK, resp.StatusCode)
b := make([]byte, 100)
Expand All @@ -182,7 +185,10 @@ func TestLoadHTMLFiles(t *testing.T) {
})
})
go engine.Run()
time.Sleep(200 * time.Millisecond)
defer func() {
engine.Close()
}()
time.Sleep(1 * time.Second)
resp, _ := http.Get("http://127.0.0.1:8891/raw")
assert.DeepEqual(t, consts.StatusOK, resp.StatusCode)
b := make([]byte, 100)
Expand Down Expand Up @@ -227,7 +233,7 @@ func TestServer_Run(t *testing.T) {
ctx.Redirect(consts.StatusMovedPermanently, []byte("http://127.0.0.1:8899/test"))
})
go hertz.Run()
time.Sleep(100 * time.Microsecond)
time.Sleep(1 * time.Second)
resp, err := http.Get("http://127.0.0.1:8899/test")
assert.Nil(t, err)
assert.DeepEqual(t, consts.StatusOK, resp.StatusCode)
Expand Down Expand Up @@ -260,7 +266,10 @@ func TestNotAbsolutePath(t *testing.T) {
ctx.Write(ctx.Request.Body())
})
go engine.Run()
time.Sleep(200 * time.Microsecond)
defer func() {
engine.Close()
}()
time.Sleep(1 * time.Second)

s := "POST ?a=b HTTP/1.1\r\nHost: a.b.c\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343"
zr := mock.NewZeroCopyReader(s)
Expand Down Expand Up @@ -299,7 +308,10 @@ func TestNotAbsolutePathWithRawPath(t *testing.T) {
engine.POST("/a", func(c context.Context, ctx *app.RequestContext) {
})
go engine.Run()
time.Sleep(200 * time.Microsecond)
defer func() {
engine.Close()
}()
time.Sleep(1 * time.Second)

s := "POST ?a=b HTTP/1.1\r\nHost: a.b.c\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343"
zr := mock.NewZeroCopyReader(s)
Expand Down Expand Up @@ -374,7 +386,10 @@ func TestWithBasePath(t *testing.T) {
engine.POST("/test", func(c context.Context, ctx *app.RequestContext) {
})
go engine.Run()
time.Sleep(500 * time.Microsecond)
defer func() {
engine.Close()
}()
time.Sleep(1 * time.Second)
var r http.Request
r.ParseForm()
r.Form.Add("xxxxxx", "xxx")
Expand All @@ -389,7 +404,10 @@ func TestNotEnoughBodySize(t *testing.T) {
engine.POST("/test", func(c context.Context, ctx *app.RequestContext) {
})
go engine.Run()
time.Sleep(200 * time.Microsecond)
defer func() {
engine.Close()
}()
time.Sleep(1 * time.Second)
var r http.Request
r.ParseForm()
r.Form.Add("xxxxxx", "xxx")
Expand All @@ -406,7 +424,10 @@ func TestEnoughBodySize(t *testing.T) {
engine.POST("/test", func(c context.Context, ctx *app.RequestContext) {
})
go engine.Run()
time.Sleep(200 * time.Microsecond)
defer func() {
engine.Close()
}()
time.Sleep(1 * time.Second)
var r http.Request
r.ParseForm()
r.Form.Add("xxxxxx", "xxx")
Expand Down
36 changes: 36 additions & 0 deletions pkg/common/utils/env.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Copyright 2024 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package utils

import (
"os"
"strconv"
"strings"

"github.com/cloudwego/hertz/pkg/common/errors"
)

// Get bool from env
func GetBoolFromEnv(key string) (bool, error) {
value, isExist := os.LookupEnv(key)
if !isExist {
return false, errors.NewPublic("env not exist")
}

value = strings.TrimSpace(value)
return strconv.ParseBool(value)
}
31 changes: 27 additions & 4 deletions pkg/protocol/http1/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
errs "github.com/cloudwego/hertz/pkg/common/errors"
"github.com/cloudwego/hertz/pkg/common/tracer/stats"
"github.com/cloudwego/hertz/pkg/common/tracer/traceinfo"
"github.com/cloudwego/hertz/pkg/common/utils"
"github.com/cloudwego/hertz/pkg/network"
"github.com/cloudwego/hertz/pkg/protocol"
"github.com/cloudwego/hertz/pkg/protocol/consts"
Expand All @@ -41,6 +42,12 @@ import (
"github.com/cloudwego/hertz/pkg/protocol/suite"
)

func init() {
if b, err := utils.GetBoolFromEnv("HERTZ_DISABLE_REQUEST_CONTEXT_POOL"); err == nil {
disabaleRequestContextPool = b
}
}

// NextProtoTLS is the NPN/ALPN protocol negotiated during
// HTTP/1.1's TLS setup.
// Also used for server addressing
Expand All @@ -51,6 +58,8 @@ var (
errIdleTimeout = errs.New(errs.ErrIdleTimeout, errs.ErrorTypePrivate, nil)
errShortConnection = errs.New(errs.ErrShortConnection, errs.ErrorTypePublic, "server is going to close the connection")
errUnexpectedEOF = errs.NewPublic(io.ErrUnexpectedEOF.Error() + " when reading request")

disabaleRequestContextPool = false
)

type Option struct {
Expand Down Expand Up @@ -80,6 +89,21 @@ type Server struct {
eventStackPool *sync.Pool
}

func (s Server) getRequestContext() *app.RequestContext {
if disabaleRequestContextPool {
return &app.RequestContext{}
}
return s.Core.GetCtxPool().Get().(*app.RequestContext)
}

func (s Server) putRequestContext(ctx *app.RequestContext) {
if disabaleRequestContextPool {
return
}
ctx.Reset()
s.Core.GetCtxPool().Put(ctx)
}

func (s Server) Serve(c context.Context, conn network.Conn) (err error) {
var (
zr network.Reader
Expand All @@ -97,8 +121,8 @@ func (s Server) Serve(c context.Context, conn network.Conn) (err error) {
// 1. Get a request context
// 2. Prepare it
// 3. Process it
// 4. Reset and recycle
ctx = s.Core.GetCtxPool().Get().(*app.RequestContext)
// 4. Reset and recycle(in pooled mode)
ctx = s.getRequestContext()

traceCtl = s.Core.GetTracer()
eventsToTrigger *eventStack
Expand Down Expand Up @@ -138,8 +162,7 @@ func (s Server) Serve(c context.Context, conn network.Conn) (err error) {
return
}

ctx.Reset()
s.Core.GetCtxPool().Put(ctx)
s.putRequestContext(ctx)
}()

ctx.HTMLRender = s.HTMLRender
Expand Down
40 changes: 40 additions & 0 deletions pkg/protocol/http1/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,46 @@ func TestDefaultWriter(t *testing.T) {
assert.DeepEqual(t, "hello, hertz", string(response.Body()))
}

func TestServerDisableReqCtxPool(t *testing.T) {
server := &Server{}
reqCtx := &app.RequestContext{}
server.Core = &mockCore{
ctxPool: &sync.Pool{New: func() interface{} {
reqCtx.Set("POOL_KEY", "in pool")
return reqCtx
}},
mockHandler: func(c context.Context, ctx *app.RequestContext) {
if ctx.GetString("POOL_KEY") != "in pool" {
t.Fatal("reqCtx is not in pool")
}
},
isRunning: true,
}
defaultConn := mock.NewConn("GET / HTTP/1.1\nHost: foobar.com\n\n")
err := server.Serve(context.TODO(), defaultConn)
assert.Nil(t, err)
disabaleRequestContextPool = true
defer func() {
// reset global variable
disabaleRequestContextPool = false
}()
server.Core = &mockCore{
ctxPool: &sync.Pool{New: func() interface{} {
reqCtx.Set("POOL_KEY", "in pool")
return reqCtx
}},
mockHandler: func(c context.Context, ctx *app.RequestContext) {
if len(ctx.GetString("POOL_KEY")) != 0 {
t.Fatal("must not get pool key")
}
},
isRunning: true,
}
defaultConn = mock.NewConn("GET / HTTP/1.1\nHost: foobar.com\n\n")
err = server.Serve(context.TODO(), defaultConn)
assert.Nil(t, err)
}

func TestHijackResponseWriter(t *testing.T) {
server := &Server{}
reqCtx := &app.RequestContext{}
Expand Down
10 changes: 5 additions & 5 deletions pkg/route/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,11 +349,6 @@ func (engine *Engine) Run() (err error) {
return err
}

if err = engine.MarkAsRunning(); err != nil {
return err
}
defer atomic.StoreUint32(&engine.status, statusClosed)

// trigger hooks if any
ctx := context.Background()
for i := range engine.OnRun {
Expand All @@ -362,6 +357,11 @@ func (engine *Engine) Run() (err error) {
}
}

if err = engine.MarkAsRunning(); err != nil {
return err
}
defer atomic.StoreUint32(&engine.status, statusClosed)

return engine.listenAndServe()
}

Expand Down
12 changes: 7 additions & 5 deletions pkg/route/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -875,20 +875,22 @@ func TestEngineShutdown(t *testing.T) {
defaultTransporter = standard.NewTransporter
mockCtxCallback := func(ctx context.Context) {}
// Test case 1: serve not running error
engine := NewEngine(config.NewOptions(nil))
opt := config.NewOptions(nil)
opt.Addr = "127.0.0.1:10027"
engine := NewEngine(opt)
ctx1, cancel1 := context.WithTimeout(context.Background(), time.Second)
defer cancel1()
err := engine.Shutdown(ctx1)
assert.DeepEqual(t, errStatusNotRunning, err)

// Test case 2: serve successfully running and shutdown
engine = NewEngine(config.NewOptions(nil))
engine = NewEngine(opt)
engine.OnShutdown = []CtxCallback{mockCtxCallback}
go func() {
engine.Run()
}()
// wait for engine to start
time.Sleep(100 * time.Millisecond)
time.Sleep(1 * time.Second)

ctx2, cancel2 := context.WithTimeout(context.Background(), time.Second)
defer cancel2()
Expand All @@ -897,14 +899,14 @@ func TestEngineShutdown(t *testing.T) {
assert.DeepEqual(t, statusClosed, atomic.LoadUint32(&engine.status))

// Test case 3: serve successfully running and shutdown with deregistry error
engine = NewEngine(config.NewOptions(nil))
engine = NewEngine(opt)
engine.OnShutdown = []CtxCallback{mockCtxCallback}
engine.options.Registry = &mockDeregsitryErr{}
go func() {
engine.Run()
}()
// wait for engine to start
time.Sleep(100 * time.Millisecond)
time.Sleep(1 * time.Second)

ctx3, cancel3 := context.WithTimeout(context.Background(), time.Second)
defer cancel3()
Expand Down
Loading