diff --git a/pkg/app/client/client_test.go b/pkg/app/client/client_test.go index 5f3e8c62f..ed5959bea 100644 --- a/pkg/app/client/client_test.go +++ b/pkg/app/client/client_test.go @@ -92,7 +92,12 @@ func TestCloseIdleConnections(t *testing.T) { defer func() { engine.Close() }() - time.Sleep(time.Millisecond * 500) + for { + time.Sleep(1 * time.Second) + if engine.IsRunning() { + break + } + } c, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, 1*time.Second, nil))) @@ -136,7 +141,12 @@ func TestClientInvalidURI(t *testing.T) { defer func() { engine.Close() }() - time.Sleep(time.Millisecond * 500) + for { + time.Sleep(1 * time.Second) + if engine.IsRunning() { + break + } + } c, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, 1*time.Second, nil))) req, res := protocol.AcquireRequest(), protocol.AcquireResponse() @@ -170,7 +180,12 @@ func TestClientGetWithBody(t *testing.T) { defer func() { engine.Close() }() - time.Sleep(time.Millisecond * 500) + for { + time.Sleep(1 * time.Second) + if engine.IsRunning() { + break + } + } c, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, 1*time.Second, nil))) req, res := protocol.AcquireRequest(), protocol.AcquireResponse() @@ -205,7 +220,12 @@ func TestClientPostBodyStream(t *testing.T) { defer func() { engine.Close() }() - time.Sleep(time.Millisecond * 500) + for { + time.Sleep(1 * time.Second) + if engine.IsRunning() { + break + } + } cStream, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, 1*time.Second, nil)), WithResponseBodyStream(true)) args := &protocol.Args{} @@ -246,7 +266,12 @@ func TestClientURLAuth(t *testing.T) { defer func() { engine.Close() }() - time.Sleep(time.Millisecond * 500) + for { + time.Sleep(1 * time.Second) + if engine.IsRunning() { + break + } + } c, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, 1*time.Second, nil))) for up, expected := range cases { @@ -278,7 +303,12 @@ func TestClientNilResp(t *testing.T) { defer func() { engine.Close() }() - time.Sleep(time.Millisecond * 500) + for { + time.Sleep(1 * time.Second) + if engine.IsRunning() { + break + } + } c, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, 1*time.Second, nil))) @@ -301,7 +331,15 @@ func TestClientParseConn(t *testing.T) { engine.GET("/", func(c context.Context, ctx *app.RequestContext) { }) go engine.Run() - time.Sleep(time.Millisecond * 500) + defer func() { + engine.Close() + }() + for { + time.Sleep(1 * time.Second) + if engine.IsRunning() { + break + } + } c, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, 1*time.Second, nil))) req, res := protocol.AcquireRequest(), protocol.AcquireResponse() @@ -343,7 +381,12 @@ func TestClientPostArgs(t *testing.T) { defer func() { engine.Close() }() - time.Sleep(time.Millisecond * 500) + for { + time.Sleep(1 * time.Second) + if engine.IsRunning() { + break + } + } c, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, 1*time.Second, nil))) req, res := protocol.AcquireRequest(), protocol.AcquireResponse() defer func() { @@ -403,20 +446,22 @@ func TestClientReadTimeout(t *testing.T) { t.Skip("skipping test in short mode") } - timeout := false opt := config.NewOptions([]config.Option{}) - opt.Addr = "localhost:10008" + opt.Addr = "localhost:10024" engine := route.NewEngine(opt) - engine.GET("/", func(c context.Context, ctx *app.RequestContext) { - if timeout { - time.Sleep(time.Minute) - } else { - timeout = true - } + engine.GET("/normal", func(c context.Context, ctx *app.RequestContext) { + ctx.String(201, "ok") + }) + engine.GET("/timeout", func(c context.Context, ctx *app.RequestContext) { + time.Sleep(time.Second * 60) + ctx.String(202, "timeout ok") }) go engine.Run() - time.Sleep(time.Millisecond * 500) + defer func() { + engine.Close() + }() + time.Sleep(time.Second * 1) c := &http1.HostClient{ ClientOptions: &http1.ClientOptions{ @@ -430,7 +475,7 @@ func TestClientReadTimeout(t *testing.T) { req := protocol.AcquireRequest() res := protocol.AcquireResponse() - req.SetRequestURI("http://" + opt.Addr) + req.SetRequestURI("http://" + opt.Addr + "/normal") req.Header.SetMethod(consts.MethodGet) // Setting Connection: Close will make the connection be returned to the pool. @@ -448,13 +493,17 @@ func TestClientReadTimeout(t *testing.T) { req := protocol.AcquireRequest() res := protocol.AcquireResponse() - req.SetRequestURI("http://" + opt.Addr) + req.SetRequestURI("http://" + opt.Addr + "/timeout") req.Header.SetMethod(consts.MethodGet) req.SetConnectionClose() if err := c.Do(context.Background(), req, res); !errors.Is(err, errs.ErrTimeout) { - if !strings.Contains(err.Error(), "timeout") { - t.Errorf("expected ErrTimeout got %#v", err) + if err == nil { + t.Errorf("expected ErrTimeout got nil, req url: %s, read resp body: %s, status: %d", string(req.URI().FullURI()), string(res.Body()), res.StatusCode()) + } else { + if !strings.Contains(err.Error(), "timeout") { + t.Errorf("expected ErrTimeout got %#v", err) + } } } @@ -488,7 +537,12 @@ func TestClientDefaultUserAgent(t *testing.T) { defer func() { engine.Close() }() - time.Sleep(time.Millisecond * 500) + for { + time.Sleep(1 * time.Second) + if engine.IsRunning() { + break + } + } c, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, 1*time.Second, nil))) req := protocol.AcquireRequest() @@ -522,7 +576,12 @@ func TestClientSetUserAgent(t *testing.T) { defer func() { engine.Close() }() - time.Sleep(time.Millisecond * 500) + for { + time.Sleep(1 * time.Second) + if engine.IsRunning() { + break + } + } userAgent := "I'm not hertz" c, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, time.Second, nil)), WithName(userAgent)) @@ -553,7 +612,12 @@ func TestClientNoUserAgent(t *testing.T) { defer func() { engine.Close() }() - time.Sleep(time.Millisecond * 500) + for { + time.Sleep(1 * time.Second) + if engine.IsRunning() { + break + } + } c, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, time.Second, nil)), WithDialTimeout(1*time.Second), WithNoDefaultUserAgentHeader(true)) req := protocol.AcquireRequest() @@ -634,7 +698,12 @@ func TestClientDoWithCustomHeaders(t *testing.T) { defer func() { engine.Close() }() - time.Sleep(time.Millisecond * 500) + for { + time.Sleep(1 * time.Second) + if engine.IsRunning() { + break + } + } // make sure that the client sends all the request headers and body. c, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, 1*time.Second, nil))) @@ -679,7 +748,12 @@ func TestClientDoTimeoutDisablePathNormalizing(t *testing.T) { defer func() { engine.Close() }() - time.Sleep(time.Millisecond * 500) + for { + time.Sleep(1 * time.Second) + if engine.IsRunning() { + break + } + } c, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, time.Second, nil)), WithDisablePathNormalizing(true)) urlWithEncodedPath := "http://example.com/encoded/Y%2BY%2FY%3D/stuff" @@ -809,7 +883,12 @@ func TestHostClientMaxConnsWithDeadline(t *testing.T) { defer func() { engine.Close() }() - time.Sleep(time.Millisecond * 500) + for { + time.Sleep(1 * time.Second) + if engine.IsRunning() { + break + } + } c := &http1.HostClient{ ClientOptions: &http1.ClientOptions{ @@ -878,7 +957,12 @@ func TestHostClientMaxConnDuration(t *testing.T) { defer func() { engine.Close() }() - time.Sleep(time.Millisecond * 500) + for { + time.Sleep(1 * time.Second) + if engine.IsRunning() { + break + } + } c := &http1.HostClient{ ClientOptions: &http1.ClientOptions{ @@ -923,7 +1007,12 @@ func TestHostClientMultipleAddrs(t *testing.T) { defer func() { engine.Close() }() - time.Sleep(time.Millisecond * 500) + for { + time.Sleep(1 * time.Second) + if engine.IsRunning() { + break + } + } dialsCount := make(map[string]int) c := &http1.HostClient{ @@ -990,7 +1079,7 @@ func TestClientFollowRedirects(t *testing.T) { defer func() { engine.Close() }() - time.Sleep(time.Millisecond * 500) + time.Sleep(time.Second * 2) c := &http1.HostClient{ ClientOptions: &http1.ClientOptions{ @@ -1084,7 +1173,12 @@ func TestHostClientMaxConnWaitTimeoutSuccess(t *testing.T) { defer func() { engine.Close() }() - time.Sleep(time.Millisecond * 500) + for { + time.Sleep(1 * time.Second) + if engine.IsRunning() { + break + } + } c := &http1.HostClient{ ClientOptions: &http1.ClientOptions{ @@ -1153,7 +1247,12 @@ func TestHostClientMaxConnWaitTimeoutError(t *testing.T) { defer func() { engine.Close() }() - time.Sleep(time.Millisecond * 500) + for { + time.Sleep(1 * time.Second) + if engine.IsRunning() { + break + } + } c := &http1.HostClient{ ClientOptions: &http1.ClientOptions{ @@ -1215,7 +1314,10 @@ func TestNewClient(t *testing.T) { ctx.SetBodyString("pong") }) go engine.Run() - time.Sleep(time.Millisecond * 500) + defer func() { + engine.Close() + }() + time.Sleep(1 * time.Second) client, err := NewClient(WithDialTimeout(2 * time.Second)) if err != nil { @@ -1240,7 +1342,10 @@ func TestUseShortConnection(t *testing.T) { engine.GET("/", func(c context.Context, ctx *app.RequestContext) { }) go engine.Run() - time.Sleep(time.Millisecond * 500) + defer func() { + engine.Close() + }() + time.Sleep(1 * time.Second) c, _ := NewClient(WithKeepAlive(false)) var wg sync.WaitGroup @@ -1284,8 +1389,11 @@ func TestPostWithFormData(t *testing.T) { ctx.Data(consts.StatusOK, "text/plain; charset=utf-8", []byte(ans)) }) go engine.Run() + defer func() { + engine.Close() + }() - time.Sleep(100 * time.Millisecond) + time.Sleep(1 * time.Second) client, _ := NewClient() req := protocol.AcquireRequest() rsp := protocol.AcquireResponse() @@ -1335,8 +1443,11 @@ func TestPostWithMultipartField(t *testing.T) { t.Log(req.GetHTTP1Request(&ctx.Request).String()) }) go engine.Run() + defer func() { + engine.Close() + }() - time.Sleep(100 * time.Millisecond) + time.Sleep(1 * time.Second) client, _ := NewClient() req := protocol.AcquireRequest() rsp := protocol.AcquireResponse() @@ -1378,8 +1489,11 @@ func TestSetFiles(t *testing.T) { ctx.String(consts.StatusOK, fmt.Sprintf("%d files uploaded!", len(files)+2)) }) go engine.Run() + defer func() { + engine.Close() + }() - time.Sleep(100 * time.Millisecond) + time.Sleep(1 * time.Second) client, _ := NewClient() req := protocol.AcquireRequest() rsp := protocol.AcquireResponse() @@ -1426,8 +1540,11 @@ func TestSetMultipartFields(t *testing.T) { ctx.String(consts.StatusOK, fmt.Sprintf("%d files uploaded!", 2)) }) go engine.Run() + defer func() { + engine.Close() + }() - time.Sleep(100 * time.Millisecond) + time.Sleep(1 * time.Second) client, _ := NewClient(WithDialTimeout(50 * time.Millisecond)) req := protocol.AcquireRequest() rsp := protocol.AcquireResponse() @@ -1478,7 +1595,10 @@ func TestClientReadResponseBodyStream(t *testing.T) { c.String(consts.StatusOK, part1+part2) }) go engine.Run() - time.Sleep(100 * time.Millisecond) + defer func() { + engine.Close() + }() + time.Sleep(1 * time.Second) client, _ := NewClient(WithResponseBodyStream(true)) req, resp := protocol.AcquireRequest(), protocol.AcquireResponse() @@ -1528,7 +1648,10 @@ func TestWithBasicAuth(t *testing.T) { } }) go engine.Run() - time.Sleep(100 * time.Millisecond) + defer func() { + engine.Close() + }() + time.Sleep(1 * time.Second) client, _ := NewClient() req := protocol.AcquireRequest() rsp := protocol.AcquireResponse() @@ -1896,7 +2019,10 @@ func TestClientReadResponseBodyStreamWithDoubleRequest(t *testing.T) { c.String(consts.StatusOK, part1+part2) }) go engine.Run() - time.Sleep(100 * time.Millisecond) + defer func() { + engine.Close() + }() + time.Sleep(1 * time.Second) client, _ := NewClient(WithResponseBodyStream(true)) req, resp := protocol.AcquireRequest(), protocol.AcquireResponse() @@ -1966,7 +2092,10 @@ func TestClientReadResponseBodyStreamWithConnectionClose(t *testing.T) { c.String(consts.StatusOK, part1) }) go engine.Run() - time.Sleep(100 * time.Millisecond) + defer func() { + engine.Close() + }() + time.Sleep(1 * time.Second) client, _ := NewClient(WithResponseBodyStream(true)) @@ -2276,7 +2405,7 @@ func TestClientDoWithDialFunc(t *testing.T) { defer func() { engine.Close() }() - time.Sleep(time.Millisecond * 500) + time.Sleep(1 * time.Second) c, _ := NewClient(WithDialFunc(func(addr string) (network.Conn, error) { return dialer.DialConnection(opt.Network, opt.Addr, time.Second, nil) @@ -2307,11 +2436,14 @@ func TestClientDoWithDialFunc(t *testing.T) { func TestClientState(t *testing.T) { opt := config.NewOptions([]config.Option{}) - opt.Addr = ":10037" + opt.Addr = "127.0.0.1:10037" engine := route.NewEngine(opt) go engine.Run() + defer func() { + engine.Close() + }() - time.Sleep(time.Millisecond) + time.Sleep(1 * time.Second) state := int32(0) client, _ := NewClient( @@ -2351,14 +2483,16 @@ func TestClientRetryErr(t *testing.T) { ctx.SetStatusCode(200) }) go engine.Run() - time.Sleep(100 * time.Millisecond) + defer func() { + engine.Close() + }() + time.Sleep(1 * time.Second) c, _ := NewClient(WithRetryConfig(retry.WithMaxAttemptTimes(3))) _, _, err := c.Get(context.Background(), nil, "http://127.0.0.1:10136/ping") assert.Nil(t, err) l.Lock() assert.DeepEqual(t, 1, retryNum) l.Unlock() - engine.Close() }) t.Run("502", func(t *testing.T) { @@ -2374,7 +2508,10 @@ func TestClientRetryErr(t *testing.T) { ctx.SetStatusCode(502) }) go engine.Run() - time.Sleep(100 * time.Millisecond) + defer func() { + engine.Close() + }() + time.Sleep(1 * time.Second) c, _ := NewClient(WithRetryConfig(retry.WithMaxAttemptTimes(3))) c.SetRetryIfFunc(func(req *protocol.Request, resp *protocol.Response, err error) bool { return resp.StatusCode() == 502 @@ -2384,6 +2521,5 @@ func TestClientRetryErr(t *testing.T) { l.Lock() assert.DeepEqual(t, 3, retryNum) l.Unlock() - engine.Close() }) } diff --git a/pkg/app/server/binding/internal/decoder/text_decoder.go b/pkg/app/server/binding/internal/decoder/text_decoder.go index 8b53c2bf5..fb598239a 100644 --- a/pkg/app/server/binding/internal/decoder/text_decoder.go +++ b/pkg/app/server/binding/internal/decoder/text_decoder.go @@ -87,7 +87,7 @@ func SelectTextDecoder(rt reflect.Type) (TextDecoder, error) { return &interfaceDecoder{}, nil } - return nil, fmt.Errorf("unsupported type " + rt.String()) + return nil, fmt.Errorf("unsupported type %s", rt.String()) } type boolDecoder struct{} diff --git a/pkg/app/server/hertz_test.go b/pkg/app/server/hertz_test.go index d20db2f96..09e77dd40 100644 --- a/pkg/app/server/hertz_test.go +++ b/pkg/app/server/hertz_test.go @@ -158,7 +158,10 @@ func TestLoadHTMLGlob(t *testing.T) { }) }) go engine.Run() - time.Sleep(200 * time.Millisecond) + defer func() { + engine.Close() + }() + time.Sleep(1 * time.Second) resp, _ := http.Get("http://127.0.0.1:8893/index") assert.DeepEqual(t, consts.StatusOK, resp.StatusCode) b := make([]byte, 100) @@ -182,7 +185,10 @@ func TestLoadHTMLFiles(t *testing.T) { }) }) go engine.Run() - time.Sleep(200 * time.Millisecond) + defer func() { + engine.Close() + }() + time.Sleep(1 * time.Second) resp, _ := http.Get("http://127.0.0.1:8891/raw") assert.DeepEqual(t, consts.StatusOK, resp.StatusCode) b := make([]byte, 100) @@ -227,7 +233,7 @@ func TestServer_Run(t *testing.T) { ctx.Redirect(consts.StatusMovedPermanently, []byte("http://127.0.0.1:8899/test")) }) go hertz.Run() - time.Sleep(100 * time.Microsecond) + time.Sleep(1 * time.Second) resp, err := http.Get("http://127.0.0.1:8899/test") assert.Nil(t, err) assert.DeepEqual(t, consts.StatusOK, resp.StatusCode) @@ -260,7 +266,10 @@ func TestNotAbsolutePath(t *testing.T) { ctx.Write(ctx.Request.Body()) }) go engine.Run() - time.Sleep(200 * time.Microsecond) + defer func() { + engine.Close() + }() + time.Sleep(1 * time.Second) s := "POST ?a=b HTTP/1.1\r\nHost: a.b.c\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343" zr := mock.NewZeroCopyReader(s) @@ -299,7 +308,10 @@ func TestNotAbsolutePathWithRawPath(t *testing.T) { engine.POST("/a", func(c context.Context, ctx *app.RequestContext) { }) go engine.Run() - time.Sleep(200 * time.Microsecond) + defer func() { + engine.Close() + }() + time.Sleep(1 * time.Second) s := "POST ?a=b HTTP/1.1\r\nHost: a.b.c\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343" zr := mock.NewZeroCopyReader(s) @@ -374,7 +386,10 @@ func TestWithBasePath(t *testing.T) { engine.POST("/test", func(c context.Context, ctx *app.RequestContext) { }) go engine.Run() - time.Sleep(500 * time.Microsecond) + defer func() { + engine.Close() + }() + time.Sleep(1 * time.Second) var r http.Request r.ParseForm() r.Form.Add("xxxxxx", "xxx") @@ -389,7 +404,10 @@ func TestNotEnoughBodySize(t *testing.T) { engine.POST("/test", func(c context.Context, ctx *app.RequestContext) { }) go engine.Run() - time.Sleep(200 * time.Microsecond) + defer func() { + engine.Close() + }() + time.Sleep(1 * time.Second) var r http.Request r.ParseForm() r.Form.Add("xxxxxx", "xxx") @@ -406,7 +424,10 @@ func TestEnoughBodySize(t *testing.T) { engine.POST("/test", func(c context.Context, ctx *app.RequestContext) { }) go engine.Run() - time.Sleep(200 * time.Microsecond) + defer func() { + engine.Close() + }() + time.Sleep(1 * time.Second) var r http.Request r.ParseForm() r.Form.Add("xxxxxx", "xxx") diff --git a/pkg/common/utils/env.go b/pkg/common/utils/env.go new file mode 100644 index 000000000..291b5e3bd --- /dev/null +++ b/pkg/common/utils/env.go @@ -0,0 +1,36 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package utils + +import ( + "os" + "strconv" + "strings" + + "github.com/cloudwego/hertz/pkg/common/errors" +) + +// Get bool from env +func GetBoolFromEnv(key string) (bool, error) { + value, isExist := os.LookupEnv(key) + if !isExist { + return false, errors.NewPublic("env not exist") + } + + value = strings.TrimSpace(value) + return strconv.ParseBool(value) +} diff --git a/pkg/protocol/http1/server.go b/pkg/protocol/http1/server.go index 3ea659603..eb46f6533 100644 --- a/pkg/protocol/http1/server.go +++ b/pkg/protocol/http1/server.go @@ -32,6 +32,7 @@ import ( errs "github.com/cloudwego/hertz/pkg/common/errors" "github.com/cloudwego/hertz/pkg/common/tracer/stats" "github.com/cloudwego/hertz/pkg/common/tracer/traceinfo" + "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/network" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/consts" @@ -41,6 +42,12 @@ import ( "github.com/cloudwego/hertz/pkg/protocol/suite" ) +func init() { + if b, err := utils.GetBoolFromEnv("HERTZ_DISABLE_REQUEST_CONTEXT_POOL"); err == nil { + disabaleRequestContextPool = b + } +} + // NextProtoTLS is the NPN/ALPN protocol negotiated during // HTTP/1.1's TLS setup. // Also used for server addressing @@ -51,6 +58,8 @@ var ( errIdleTimeout = errs.New(errs.ErrIdleTimeout, errs.ErrorTypePrivate, nil) errShortConnection = errs.New(errs.ErrShortConnection, errs.ErrorTypePublic, "server is going to close the connection") errUnexpectedEOF = errs.NewPublic(io.ErrUnexpectedEOF.Error() + " when reading request") + + disabaleRequestContextPool = false ) type Option struct { @@ -80,6 +89,21 @@ type Server struct { eventStackPool *sync.Pool } +func (s Server) getRequestContext() *app.RequestContext { + if disabaleRequestContextPool { + return &app.RequestContext{} + } + return s.Core.GetCtxPool().Get().(*app.RequestContext) +} + +func (s Server) putRequestContext(ctx *app.RequestContext) { + if disabaleRequestContextPool { + return + } + ctx.Reset() + s.Core.GetCtxPool().Put(ctx) +} + func (s Server) Serve(c context.Context, conn network.Conn) (err error) { var ( zr network.Reader @@ -97,8 +121,8 @@ func (s Server) Serve(c context.Context, conn network.Conn) (err error) { // 1. Get a request context // 2. Prepare it // 3. Process it - // 4. Reset and recycle - ctx = s.Core.GetCtxPool().Get().(*app.RequestContext) + // 4. Reset and recycle(in pooled mode) + ctx = s.getRequestContext() traceCtl = s.Core.GetTracer() eventsToTrigger *eventStack @@ -138,8 +162,7 @@ func (s Server) Serve(c context.Context, conn network.Conn) (err error) { return } - ctx.Reset() - s.Core.GetCtxPool().Put(ctx) + s.putRequestContext(ctx) }() ctx.HTMLRender = s.HTMLRender diff --git a/pkg/protocol/http1/server_test.go b/pkg/protocol/http1/server_test.go index 2263ece77..d478b36fc 100644 --- a/pkg/protocol/http1/server_test.go +++ b/pkg/protocol/http1/server_test.go @@ -218,6 +218,46 @@ func TestDefaultWriter(t *testing.T) { assert.DeepEqual(t, "hello, hertz", string(response.Body())) } +func TestServerDisableReqCtxPool(t *testing.T) { + server := &Server{} + reqCtx := &app.RequestContext{} + server.Core = &mockCore{ + ctxPool: &sync.Pool{New: func() interface{} { + reqCtx.Set("POOL_KEY", "in pool") + return reqCtx + }}, + mockHandler: func(c context.Context, ctx *app.RequestContext) { + if ctx.GetString("POOL_KEY") != "in pool" { + t.Fatal("reqCtx is not in pool") + } + }, + isRunning: true, + } + defaultConn := mock.NewConn("GET / HTTP/1.1\nHost: foobar.com\n\n") + err := server.Serve(context.TODO(), defaultConn) + assert.Nil(t, err) + disabaleRequestContextPool = true + defer func() { + // reset global variable + disabaleRequestContextPool = false + }() + server.Core = &mockCore{ + ctxPool: &sync.Pool{New: func() interface{} { + reqCtx.Set("POOL_KEY", "in pool") + return reqCtx + }}, + mockHandler: func(c context.Context, ctx *app.RequestContext) { + if len(ctx.GetString("POOL_KEY")) != 0 { + t.Fatal("must not get pool key") + } + }, + isRunning: true, + } + defaultConn = mock.NewConn("GET / HTTP/1.1\nHost: foobar.com\n\n") + err = server.Serve(context.TODO(), defaultConn) + assert.Nil(t, err) +} + func TestHijackResponseWriter(t *testing.T) { server := &Server{} reqCtx := &app.RequestContext{} diff --git a/pkg/route/engine.go b/pkg/route/engine.go index 4881ee9cf..2a7b60ca6 100644 --- a/pkg/route/engine.go +++ b/pkg/route/engine.go @@ -349,11 +349,6 @@ func (engine *Engine) Run() (err error) { return err } - if err = engine.MarkAsRunning(); err != nil { - return err - } - defer atomic.StoreUint32(&engine.status, statusClosed) - // trigger hooks if any ctx := context.Background() for i := range engine.OnRun { @@ -362,6 +357,11 @@ func (engine *Engine) Run() (err error) { } } + if err = engine.MarkAsRunning(); err != nil { + return err + } + defer atomic.StoreUint32(&engine.status, statusClosed) + return engine.listenAndServe() } diff --git a/pkg/route/engine_test.go b/pkg/route/engine_test.go index ea1bc5fd9..a5da5dc56 100644 --- a/pkg/route/engine_test.go +++ b/pkg/route/engine_test.go @@ -875,20 +875,22 @@ func TestEngineShutdown(t *testing.T) { defaultTransporter = standard.NewTransporter mockCtxCallback := func(ctx context.Context) {} // Test case 1: serve not running error - engine := NewEngine(config.NewOptions(nil)) + opt := config.NewOptions(nil) + opt.Addr = "127.0.0.1:10027" + engine := NewEngine(opt) ctx1, cancel1 := context.WithTimeout(context.Background(), time.Second) defer cancel1() err := engine.Shutdown(ctx1) assert.DeepEqual(t, errStatusNotRunning, err) // Test case 2: serve successfully running and shutdown - engine = NewEngine(config.NewOptions(nil)) + engine = NewEngine(opt) engine.OnShutdown = []CtxCallback{mockCtxCallback} go func() { engine.Run() }() // wait for engine to start - time.Sleep(100 * time.Millisecond) + time.Sleep(1 * time.Second) ctx2, cancel2 := context.WithTimeout(context.Background(), time.Second) defer cancel2() @@ -897,14 +899,14 @@ func TestEngineShutdown(t *testing.T) { assert.DeepEqual(t, statusClosed, atomic.LoadUint32(&engine.status)) // Test case 3: serve successfully running and shutdown with deregistry error - engine = NewEngine(config.NewOptions(nil)) + engine = NewEngine(opt) engine.OnShutdown = []CtxCallback{mockCtxCallback} engine.options.Registry = &mockDeregsitryErr{} go func() { engine.Run() }() // wait for engine to start - time.Sleep(100 * time.Millisecond) + time.Sleep(1 * time.Second) ctx3, cancel3 := context.WithTimeout(context.Background(), time.Second) defer cancel3()