From 764df4d7fde9f43f422876ead37a36c43bc7c29e Mon Sep 17 00:00:00 2001 From: Jingyang Kang Date: Wed, 17 Jul 2024 11:05:45 +0800 Subject: [PATCH] fix: hertz panic when edit ctx.Params on HandleFunc (#1150) --- pkg/route/engine.go | 7 +++++++ pkg/route/engine_test.go | 30 ++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/pkg/route/engine.go b/pkg/route/engine.go index 30b899dac..4881ee9cf 100644 --- a/pkg/route/engine.go +++ b/pkg/route/engine.go @@ -75,6 +75,7 @@ import ( "github.com/cloudwego/hertz/pkg/protocol/http1" "github.com/cloudwego/hertz/pkg/protocol/http1/factory" "github.com/cloudwego/hertz/pkg/protocol/suite" + "github.com/cloudwego/hertz/pkg/route/param" ) const unknownTransporterName = "unknown" @@ -749,6 +750,12 @@ func (engine *Engine) ServeHTTP(c context.Context, ctx *app.RequestContext) { return } + // if Params is re-assigned in HandlerFunc and the capacity is not enough we need to realloc + maxParams := int(engine.maxParams) + if cap(ctx.Params) < maxParams { + ctx.Params = make(param.Params, 0, maxParams) + } + // Find root of the tree for the given HTTP method t := engine.trees paramsPointer := &ctx.Params diff --git a/pkg/route/engine_test.go b/pkg/route/engine_test.go index b3e0adb30..ea1bc5fd9 100644 --- a/pkg/route/engine_test.go +++ b/pkg/route/engine_test.go @@ -1029,3 +1029,33 @@ func TestAcquireHijackConn(t *testing.T) { assert.DeepEqual(t, engine, hijackConn.e) assert.DeepEqual(t, conn, hijackConn.Conn) } + +func TestHandleParamsReassignInHandleFunc(t *testing.T) { + e := NewEngine(config.NewOptions(nil)) + routes := []string{ + "/:a/:b/:c", + } + for _, r := range routes { + e.GET(r, func(c context.Context, ctx *app.RequestContext) { + ctx.Params = make([]param.Param, 1) + ctx.String(consts.StatusOK, "") + }) + } + testRoutes := []string{ + "/aaa/bbb/ccc", + "/asd/alskja/alkdjad", + "/asd/alskja/alkdjad", + "/asd/alskja/alkdjad", + "/asd/alskja/alkdjad", + "/alksjdlakjd/ooo/askda", + "/alksjdlakjd/ooo/askda", + "/alksjdlakjd/ooo/askda", + } + ctx := e.ctxPool.Get().(*app.RequestContext) + for _, tr := range testRoutes { + r := protocol.NewRequest(http.MethodGet, tr, nil) + r.CopyTo(&ctx.Request) + e.ServeHTTP(context.Background(), ctx) + ctx.ResetWithoutConn() + } +}