Skip to content

Commit

Permalink
feat: use one option control both req & resp
Browse files Browse the repository at this point in the history
  • Loading branch information
welkeyever committed Sep 9, 2023
1 parent 935bb29 commit b04f1ac
Show file tree
Hide file tree
Showing 8 changed files with 48 additions and 63 deletions.
5 changes: 2 additions & 3 deletions pkg/app/server/hertz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -788,11 +788,10 @@ func TestSilentMode(t *testing.T) {
func TestHertzDisableHeaderNamesNormalizing(t *testing.T) {
h := New(
WithHostPorts("localhost:9212"),
WithDisableResponseHeaderNamesNormalizing(true),
WithDisableRequestHeaderNamesNormalizing(true),
WithDisableHeaderNamesNormalizing(true),
)
headerName := "CASE-senSITive-HEAder-NAME"
headerValue := "foobar baz"
headerValue := "foobar-baz"
succeed := false
h.GET("/test", func(c context.Context, ctx *app.RequestContext) {
ctx.VisitAllHeaders(func(key, value []byte) {
Expand Down
13 changes: 3 additions & 10 deletions pkg/app/server/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -347,16 +347,9 @@ func WithOnConnect(fn func(ctx context.Context, conn network.Conn) context.Conte
}}
}

// WithDisableRequestHeaderNamesNormalizing is used to set whether disable header names normalizing.
func WithDisableRequestHeaderNamesNormalizing(disable bool) config.Option {
// WithDisableHeaderNamesNormalizing is used to set whether disable header names normalizing.
func WithDisableHeaderNamesNormalizing(disable bool) config.Option {
return config.Option{F: func(o *config.Options) {
o.DisableRequestHeaderNamesNormalizing = disable
}}
}

// WithDisableResponseHeaderNamesNormalizing is used to set whether disable header names normalizing.
func WithDisableResponseHeaderNamesNormalizing(disable bool) config.Option {
return config.Option{F: func(o *config.Options) {
o.DisableResponseHeaderNamesNormalizing = disable
o.DisableHeaderNamesNormalizing = disable
}}
}
9 changes: 3 additions & 6 deletions pkg/app/server/option_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,7 @@ func TestOptions(t *testing.T) {
WithAutoReloadRender(true, 5*time.Second),
WithListenConfig(cfg),
WithAltTransport(transporter),
WithDisableRequestHeaderNamesNormalizing(true),
WithDisableResponseHeaderNamesNormalizing(true),
WithDisableHeaderNamesNormalizing(true),
})
assert.DeepEqual(t, opt.ReadTimeout, time.Second)
assert.DeepEqual(t, opt.WriteTimeout, time.Second)
Expand Down Expand Up @@ -109,8 +108,7 @@ func TestOptions(t *testing.T) {
assert.DeepEqual(t, opt.AutoReloadInterval, 5*time.Second)
assert.DeepEqual(t, opt.ListenConfig, cfg)
assert.Assert(t, reflect.TypeOf(opt.AltTransporterNewer) == reflect.TypeOf(transporter))
assert.DeepEqual(t, opt.DisableRequestHeaderNamesNormalizing, true)
assert.DeepEqual(t, opt.DisableResponseHeaderNamesNormalizing, true)
assert.DeepEqual(t, opt.DisableHeaderNamesNormalizing, true)
}

func TestDefaultOptions(t *testing.T) {
Expand Down Expand Up @@ -143,8 +141,7 @@ func TestDefaultOptions(t *testing.T) {
assert.Assert(t, opt.RegistryInfo == nil)
assert.DeepEqual(t, opt.AutoReloadRender, false)
assert.DeepEqual(t, opt.AutoReloadInterval, time.Duration(0))
assert.DeepEqual(t, opt.DisableRequestHeaderNamesNormalizing, false)
assert.DeepEqual(t, opt.DisableResponseHeaderNamesNormalizing, false)
assert.DeepEqual(t, opt.DisableHeaderNamesNormalizing, false)
}

type mockTransporter struct{}
Expand Down
8 changes: 2 additions & 6 deletions pkg/common/config/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,7 @@ type Options struct {
// * HOST -> Host
// * content-type -> Content-Type
// * cONTENT-lenGTH -> Content-Length
DisableRequestHeaderNamesNormalizing bool

DisableResponseHeaderNamesNormalizing bool
DisableHeaderNamesNormalizing bool
}

func (o *Options) Apply(opts []Option) {
Expand Down Expand Up @@ -245,9 +243,7 @@ func NewOptions(opts []Option) *Options {
Registry: registry.NoopRegistry,

// Disabled header names' normalization, default false
DisableRequestHeaderNamesNormalizing: false,

DisableResponseHeaderNamesNormalizing: false,
DisableHeaderNamesNormalizing: false,
}
options.Apply(opts)
return options
Expand Down
3 changes: 1 addition & 2 deletions pkg/common/config/option_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,7 @@ func TestDefaultOptions(t *testing.T) {
assert.DeepEqual(t, []interface{}{}, options.Tracers)
assert.DeepEqual(t, new(interface{}), options.TraceLevel)
assert.DeepEqual(t, registry.NoopRegistry, options.Registry)
assert.DeepEqual(t, false, options.DisableRequestHeaderNamesNormalizing)
assert.DeepEqual(t, false, options.DisableResponseHeaderNamesNormalizing)
assert.DeepEqual(t, false, options.DisableHeaderNamesNormalizing)
}

// TestApplyCustomOptions test apply options with custom values after init
Expand Down
2 changes: 2 additions & 0 deletions pkg/protocol/header.go
Original file line number Diff line number Diff line change
Expand Up @@ -1491,6 +1491,7 @@ func (h *RequestHeader) UserAgent() []byte {
// Disable header names' normalization only if you know what are you doing.
func (h *RequestHeader) DisableNormalizing() {
h.disableNormalizing = true
h.Trailer().DisableNormalizing()
}

func (h *RequestHeader) IsDisableNormalizing() bool {
Expand Down Expand Up @@ -1691,6 +1692,7 @@ func (h *RequestHeader) SetMethodBytes(method []byte) {
// Disable header names' normalization only if you know what are you doing.
func (h *ResponseHeader) DisableNormalizing() {
h.disableNormalizing = true
h.Trailer().DisableNormalizing()
}

// setSpecialHeader handles special headers and return true when a header is processed.
Expand Down
34 changes: 20 additions & 14 deletions pkg/protocol/http1/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,20 +54,21 @@ var (
)

type Option struct {
StreamRequestBody bool
GetOnly bool
DisablePreParseMultipartForm bool
DisableKeepalive bool
NoDefaultServerHeader bool
MaxRequestBodySize int
IdleTimeout time.Duration
ReadTimeout time.Duration
ServerName []byte
TLS *tls.Config
HTMLRender render.HTMLRender
EnableTrace bool
ContinueHandler func(header *protocol.RequestHeader) bool
HijackConnHandle func(c network.Conn, h app.HijackHandler)
StreamRequestBody bool
GetOnly bool
DisablePreParseMultipartForm bool
DisableKeepalive bool
NoDefaultServerHeader bool
DisableHeaderNamesNormalizing bool
MaxRequestBodySize int
IdleTimeout time.Duration
ReadTimeout time.Duration
ServerName []byte
TLS *tls.Config
HTMLRender render.HTMLRender
EnableTrace bool
ContinueHandler func(header *protocol.RequestHeader) bool
HijackConnHandle func(c network.Conn, h app.HijackHandler)
}

type Server struct {
Expand Down Expand Up @@ -180,6 +181,11 @@ func (s Server) Serve(c context.Context, conn network.Conn) (err error) {
})
}

if s.DisableHeaderNamesNormalizing {
ctx.Request.Header.DisableNormalizing()
ctx.Response.Header.DisableNormalizing()
}

// Read Headers
if err = req.ReadHeader(&ctx.Request.Header, zr); err == nil {
if s.EnableTrace {
Expand Down
37 changes: 15 additions & 22 deletions pkg/route/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -737,14 +737,6 @@ func (engine *Engine) allocateContext() *app.RequestContext {
ctx.Response.SetMaxKeepBodySize(engine.options.MaxKeepBodySize)
ctx.SetClientIPFunc(engine.clientIPFunc)
ctx.SetFormValueFunc(engine.formValueFunc)
if engine.options.DisableRequestHeaderNamesNormalizing {
ctx.Request.Header.DisableNormalizing()
ctx.Request.Header.Trailer().DisableNormalizing()
}
if engine.options.DisableResponseHeaderNamesNormalizing {
ctx.Response.Header.DisableNormalizing()
ctx.Response.Header.Trailer().DisableNormalizing()
}
return ctx
}

Expand Down Expand Up @@ -998,20 +990,21 @@ func iterate(method string, routes RoutesInfo, root *node) RoutesInfo {
// for built-in http1 impl only.
func newHttp1OptionFromEngine(engine *Engine) *http1.Option {
opt := &http1.Option{
StreamRequestBody: engine.options.StreamRequestBody,
GetOnly: engine.options.GetOnly,
DisablePreParseMultipartForm: engine.options.DisablePreParseMultipartForm,
DisableKeepalive: engine.options.DisableKeepalive,
NoDefaultServerHeader: engine.options.NoDefaultServerHeader,
MaxRequestBodySize: engine.options.MaxRequestBodySize,
IdleTimeout: engine.options.IdleTimeout,
ReadTimeout: engine.options.ReadTimeout,
ServerName: engine.GetServerName(),
ContinueHandler: engine.ContinueHandler,
TLS: engine.options.TLS,
HTMLRender: engine.htmlRender,
EnableTrace: engine.IsTraceEnable(),
HijackConnHandle: engine.HijackConnHandle,
StreamRequestBody: engine.options.StreamRequestBody,
GetOnly: engine.options.GetOnly,
DisablePreParseMultipartForm: engine.options.DisablePreParseMultipartForm,
DisableKeepalive: engine.options.DisableKeepalive,
NoDefaultServerHeader: engine.options.NoDefaultServerHeader,
MaxRequestBodySize: engine.options.MaxRequestBodySize,
IdleTimeout: engine.options.IdleTimeout,
ReadTimeout: engine.options.ReadTimeout,
ServerName: engine.GetServerName(),
ContinueHandler: engine.ContinueHandler,
TLS: engine.options.TLS,
HTMLRender: engine.htmlRender,
EnableTrace: engine.IsTraceEnable(),
HijackConnHandle: engine.HijackConnHandle,
DisableHeaderNamesNormalizing: engine.options.DisableHeaderNamesNormalizing,
}
// Idle timeout of standard network must not be zero. Set it to -1 seconds if it is zero.
// Due to the different triggering ways of the network library, see the actual use of this value for the detailed reasons.
Expand Down

0 comments on commit b04f1ac

Please sign in to comment.