diff --git a/pkg/app/server/hertz_test.go b/pkg/app/server/hertz_test.go index cfb0c3a8f..a5cf7d350 100644 --- a/pkg/app/server/hertz_test.go +++ b/pkg/app/server/hertz_test.go @@ -788,11 +788,10 @@ func TestSilentMode(t *testing.T) { func TestHertzDisableHeaderNamesNormalizing(t *testing.T) { h := New( WithHostPorts("localhost:9212"), - WithDisableResponseHeaderNamesNormalizing(true), - WithDisableRequestHeaderNamesNormalizing(true), + WithDisableHeaderNamesNormalizing(true), ) headerName := "CASE-senSITive-HEAder-NAME" - headerValue := "foobar baz" + headerValue := "foobar-baz" succeed := false h.GET("/test", func(c context.Context, ctx *app.RequestContext) { ctx.VisitAllHeaders(func(key, value []byte) { diff --git a/pkg/app/server/option.go b/pkg/app/server/option.go index 052caac7f..c9e3735be 100644 --- a/pkg/app/server/option.go +++ b/pkg/app/server/option.go @@ -347,16 +347,9 @@ func WithOnConnect(fn func(ctx context.Context, conn network.Conn) context.Conte }} } -// WithDisableRequestHeaderNamesNormalizing is used to set whether disable header names normalizing. -func WithDisableRequestHeaderNamesNormalizing(disable bool) config.Option { +// WithDisableHeaderNamesNormalizing is used to set whether disable header names normalizing. +func WithDisableHeaderNamesNormalizing(disable bool) config.Option { return config.Option{F: func(o *config.Options) { - o.DisableRequestHeaderNamesNormalizing = disable - }} -} - -// WithDisableResponseHeaderNamesNormalizing is used to set whether disable header names normalizing. -func WithDisableResponseHeaderNamesNormalizing(disable bool) config.Option { - return config.Option{F: func(o *config.Options) { - o.DisableResponseHeaderNamesNormalizing = disable + o.DisableHeaderNamesNormalizing = disable }} } diff --git a/pkg/app/server/option_test.go b/pkg/app/server/option_test.go index c66efd843..f5d7f7b32 100644 --- a/pkg/app/server/option_test.go +++ b/pkg/app/server/option_test.go @@ -75,8 +75,7 @@ func TestOptions(t *testing.T) { WithAutoReloadRender(true, 5*time.Second), WithListenConfig(cfg), WithAltTransport(transporter), - WithDisableRequestHeaderNamesNormalizing(true), - WithDisableResponseHeaderNamesNormalizing(true), + WithDisableHeaderNamesNormalizing(true), }) assert.DeepEqual(t, opt.ReadTimeout, time.Second) assert.DeepEqual(t, opt.WriteTimeout, time.Second) @@ -109,8 +108,7 @@ func TestOptions(t *testing.T) { assert.DeepEqual(t, opt.AutoReloadInterval, 5*time.Second) assert.DeepEqual(t, opt.ListenConfig, cfg) assert.Assert(t, reflect.TypeOf(opt.AltTransporterNewer) == reflect.TypeOf(transporter)) - assert.DeepEqual(t, opt.DisableRequestHeaderNamesNormalizing, true) - assert.DeepEqual(t, opt.DisableResponseHeaderNamesNormalizing, true) + assert.DeepEqual(t, opt.DisableHeaderNamesNormalizing, true) } func TestDefaultOptions(t *testing.T) { @@ -143,8 +141,7 @@ func TestDefaultOptions(t *testing.T) { assert.Assert(t, opt.RegistryInfo == nil) assert.DeepEqual(t, opt.AutoReloadRender, false) assert.DeepEqual(t, opt.AutoReloadInterval, time.Duration(0)) - assert.DeepEqual(t, opt.DisableRequestHeaderNamesNormalizing, false) - assert.DeepEqual(t, opt.DisableResponseHeaderNamesNormalizing, false) + assert.DeepEqual(t, opt.DisableHeaderNamesNormalizing, false) } type mockTransporter struct{} diff --git a/pkg/common/config/option.go b/pkg/common/config/option.go index b5454fa47..9ef7ddf42 100644 --- a/pkg/common/config/option.go +++ b/pkg/common/config/option.go @@ -112,9 +112,7 @@ type Options struct { // * HOST -> Host // * content-type -> Content-Type // * cONTENT-lenGTH -> Content-Length - DisableRequestHeaderNamesNormalizing bool - - DisableResponseHeaderNamesNormalizing bool + DisableHeaderNamesNormalizing bool } func (o *Options) Apply(opts []Option) { @@ -245,9 +243,7 @@ func NewOptions(opts []Option) *Options { Registry: registry.NoopRegistry, // Disabled header names' normalization, default false - DisableRequestHeaderNamesNormalizing: false, - - DisableResponseHeaderNamesNormalizing: false, + DisableHeaderNamesNormalizing: false, } options.Apply(opts) return options diff --git a/pkg/common/config/option_test.go b/pkg/common/config/option_test.go index 49991cf1a..39d92d736 100644 --- a/pkg/common/config/option_test.go +++ b/pkg/common/config/option_test.go @@ -53,8 +53,7 @@ func TestDefaultOptions(t *testing.T) { assert.DeepEqual(t, []interface{}{}, options.Tracers) assert.DeepEqual(t, new(interface{}), options.TraceLevel) assert.DeepEqual(t, registry.NoopRegistry, options.Registry) - assert.DeepEqual(t, false, options.DisableRequestHeaderNamesNormalizing) - assert.DeepEqual(t, false, options.DisableResponseHeaderNamesNormalizing) + assert.DeepEqual(t, false, options.DisableHeaderNamesNormalizing) } // TestApplyCustomOptions test apply options with custom values after init diff --git a/pkg/protocol/header.go b/pkg/protocol/header.go index 3ac88016b..f78de96f7 100644 --- a/pkg/protocol/header.go +++ b/pkg/protocol/header.go @@ -1491,6 +1491,7 @@ func (h *RequestHeader) UserAgent() []byte { // Disable header names' normalization only if you know what are you doing. func (h *RequestHeader) DisableNormalizing() { h.disableNormalizing = true + h.Trailer().DisableNormalizing() } func (h *RequestHeader) IsDisableNormalizing() bool { @@ -1691,6 +1692,7 @@ func (h *RequestHeader) SetMethodBytes(method []byte) { // Disable header names' normalization only if you know what are you doing. func (h *ResponseHeader) DisableNormalizing() { h.disableNormalizing = true + h.Trailer().DisableNormalizing() } // setSpecialHeader handles special headers and return true when a header is processed. diff --git a/pkg/protocol/http1/server.go b/pkg/protocol/http1/server.go index a71b972e0..66cb01e85 100644 --- a/pkg/protocol/http1/server.go +++ b/pkg/protocol/http1/server.go @@ -54,20 +54,21 @@ var ( ) type Option struct { - StreamRequestBody bool - GetOnly bool - DisablePreParseMultipartForm bool - DisableKeepalive bool - NoDefaultServerHeader bool - MaxRequestBodySize int - IdleTimeout time.Duration - ReadTimeout time.Duration - ServerName []byte - TLS *tls.Config - HTMLRender render.HTMLRender - EnableTrace bool - ContinueHandler func(header *protocol.RequestHeader) bool - HijackConnHandle func(c network.Conn, h app.HijackHandler) + StreamRequestBody bool + GetOnly bool + DisablePreParseMultipartForm bool + DisableKeepalive bool + NoDefaultServerHeader bool + DisableHeaderNamesNormalizing bool + MaxRequestBodySize int + IdleTimeout time.Duration + ReadTimeout time.Duration + ServerName []byte + TLS *tls.Config + HTMLRender render.HTMLRender + EnableTrace bool + ContinueHandler func(header *protocol.RequestHeader) bool + HijackConnHandle func(c network.Conn, h app.HijackHandler) } type Server struct { @@ -180,6 +181,11 @@ func (s Server) Serve(c context.Context, conn network.Conn) (err error) { }) } + if s.DisableHeaderNamesNormalizing { + ctx.Request.Header.DisableNormalizing() + ctx.Response.Header.DisableNormalizing() + } + // Read Headers if err = req.ReadHeader(&ctx.Request.Header, zr); err == nil { if s.EnableTrace { diff --git a/pkg/route/engine.go b/pkg/route/engine.go index aeda51bb6..627406e84 100644 --- a/pkg/route/engine.go +++ b/pkg/route/engine.go @@ -737,14 +737,6 @@ func (engine *Engine) allocateContext() *app.RequestContext { ctx.Response.SetMaxKeepBodySize(engine.options.MaxKeepBodySize) ctx.SetClientIPFunc(engine.clientIPFunc) ctx.SetFormValueFunc(engine.formValueFunc) - if engine.options.DisableRequestHeaderNamesNormalizing { - ctx.Request.Header.DisableNormalizing() - ctx.Request.Header.Trailer().DisableNormalizing() - } - if engine.options.DisableResponseHeaderNamesNormalizing { - ctx.Response.Header.DisableNormalizing() - ctx.Response.Header.Trailer().DisableNormalizing() - } return ctx } @@ -998,20 +990,21 @@ func iterate(method string, routes RoutesInfo, root *node) RoutesInfo { // for built-in http1 impl only. func newHttp1OptionFromEngine(engine *Engine) *http1.Option { opt := &http1.Option{ - StreamRequestBody: engine.options.StreamRequestBody, - GetOnly: engine.options.GetOnly, - DisablePreParseMultipartForm: engine.options.DisablePreParseMultipartForm, - DisableKeepalive: engine.options.DisableKeepalive, - NoDefaultServerHeader: engine.options.NoDefaultServerHeader, - MaxRequestBodySize: engine.options.MaxRequestBodySize, - IdleTimeout: engine.options.IdleTimeout, - ReadTimeout: engine.options.ReadTimeout, - ServerName: engine.GetServerName(), - ContinueHandler: engine.ContinueHandler, - TLS: engine.options.TLS, - HTMLRender: engine.htmlRender, - EnableTrace: engine.IsTraceEnable(), - HijackConnHandle: engine.HijackConnHandle, + StreamRequestBody: engine.options.StreamRequestBody, + GetOnly: engine.options.GetOnly, + DisablePreParseMultipartForm: engine.options.DisablePreParseMultipartForm, + DisableKeepalive: engine.options.DisableKeepalive, + NoDefaultServerHeader: engine.options.NoDefaultServerHeader, + MaxRequestBodySize: engine.options.MaxRequestBodySize, + IdleTimeout: engine.options.IdleTimeout, + ReadTimeout: engine.options.ReadTimeout, + ServerName: engine.GetServerName(), + ContinueHandler: engine.ContinueHandler, + TLS: engine.options.TLS, + HTMLRender: engine.htmlRender, + EnableTrace: engine.IsTraceEnable(), + HijackConnHandle: engine.HijackConnHandle, + DisableHeaderNamesNormalizing: engine.options.DisableHeaderNamesNormalizing, } // Idle timeout of standard network must not be zero. Set it to -1 seconds if it is zero. // Due to the different triggering ways of the network library, see the actual use of this value for the detailed reasons.