Skip to content

Commit

Permalink
feat: distinguishing request and response
Browse files Browse the repository at this point in the history
  • Loading branch information
li-jin-gou authored and welkeyever committed Sep 9, 2023
1 parent 404144e commit 935bb29
Show file tree
Hide file tree
Showing 7 changed files with 98 additions and 41 deletions.
37 changes: 37 additions & 0 deletions pkg/app/server/hertz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -784,3 +784,40 @@ func TestSilentMode(t *testing.T) {
t.Fatalf("unexpected error in log: %s", b.String())
}
}

func TestHertzDisableHeaderNamesNormalizing(t *testing.T) {
h := New(
WithHostPorts("localhost:9212"),
WithDisableResponseHeaderNamesNormalizing(true),
WithDisableRequestHeaderNamesNormalizing(true),
)
headerName := "CASE-senSITive-HEAder-NAME"
headerValue := "foobar baz"
succeed := false
h.GET("/test", func(c context.Context, ctx *app.RequestContext) {
ctx.VisitAllHeaders(func(key, value []byte) {
if string(key) == headerName && string(value) == headerValue {
succeed = true
return
}
})
if !succeed {
t.Fatalf("DisableHeaderNamesNormalizing failed")
} else {
ctx.Header(headerName, headerValue)
}
})

go h.Spin()
time.Sleep(100 * time.Millisecond)

cli, _ := c.NewClient(c.WithDisableHeaderNamesNormalizing(true))

r := protocol.NewRequest("GET", "http://localhost:9212/test", nil)
r.Header.DisableNormalizing()
r.Header.Set(headerName, headerValue)
res := protocol.AcquireResponse()
err := cli.Do(context.Background(), r, res)
assert.Nil(t, err)
assert.DeepEqual(t, headerValue, res.Header.Get(headerName))
}
13 changes: 10 additions & 3 deletions pkg/app/server/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -347,9 +347,16 @@ func WithOnConnect(fn func(ctx context.Context, conn network.Conn) context.Conte
}}
}

// WithDisableHeaderNamesNormalizing is used to set whether disable header names normalizing.
func WithDisableHeaderNamesNormalizing(disable bool) config.Option {
// WithDisableRequestHeaderNamesNormalizing is used to set whether disable header names normalizing.
func WithDisableRequestHeaderNamesNormalizing(disable bool) config.Option {
return config.Option{F: func(o *config.Options) {
o.DisableHeaderNamesNormalizing = disable
o.DisableRequestHeaderNamesNormalizing = disable
}}
}

// WithDisableResponseHeaderNamesNormalizing is used to set whether disable header names normalizing.
func WithDisableResponseHeaderNamesNormalizing(disable bool) config.Option {
return config.Option{F: func(o *config.Options) {
o.DisableResponseHeaderNamesNormalizing = disable
}}
}
6 changes: 6 additions & 0 deletions pkg/app/server/option_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ func TestOptions(t *testing.T) {
WithAutoReloadRender(true, 5*time.Second),
WithListenConfig(cfg),
WithAltTransport(transporter),
WithDisableRequestHeaderNamesNormalizing(true),
WithDisableResponseHeaderNamesNormalizing(true),
})
assert.DeepEqual(t, opt.ReadTimeout, time.Second)
assert.DeepEqual(t, opt.WriteTimeout, time.Second)
Expand Down Expand Up @@ -107,6 +109,8 @@ func TestOptions(t *testing.T) {
assert.DeepEqual(t, opt.AutoReloadInterval, 5*time.Second)
assert.DeepEqual(t, opt.ListenConfig, cfg)
assert.Assert(t, reflect.TypeOf(opt.AltTransporterNewer) == reflect.TypeOf(transporter))
assert.DeepEqual(t, opt.DisableRequestHeaderNamesNormalizing, true)
assert.DeepEqual(t, opt.DisableResponseHeaderNamesNormalizing, true)
}

func TestDefaultOptions(t *testing.T) {
Expand Down Expand Up @@ -139,6 +143,8 @@ func TestDefaultOptions(t *testing.T) {
assert.Assert(t, opt.RegistryInfo == nil)
assert.DeepEqual(t, opt.AutoReloadRender, false)
assert.DeepEqual(t, opt.AutoReloadInterval, time.Duration(0))
assert.DeepEqual(t, opt.DisableRequestHeaderNamesNormalizing, false)
assert.DeepEqual(t, opt.DisableResponseHeaderNamesNormalizing, false)
}

type mockTransporter struct{}
Expand Down
10 changes: 7 additions & 3 deletions pkg/common/config/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,15 +104,17 @@ type Options struct {
// Disabled header names' normalization may be useful only for proxying
// responses to other clients expecting case-sensitive header names.
//
// By default request and response header names are normalized, i.e.
// By default, request and response header names are normalized, i.e.
// The first letter and the first letters following dashes
// are uppercased, while all the other letters are lowercased.
// Examples:
//
// * HOST -> Host
// * content-type -> Content-Type
// * cONTENT-lenGTH -> Content-Length
DisableHeaderNamesNormalizing bool
DisableRequestHeaderNamesNormalizing bool

DisableResponseHeaderNamesNormalizing bool
}

func (o *Options) Apply(opts []Option) {
Expand Down Expand Up @@ -243,7 +245,9 @@ func NewOptions(opts []Option) *Options {
Registry: registry.NoopRegistry,

// Disabled header names' normalization, default false
DisableHeaderNamesNormalizing: false,
DisableRequestHeaderNamesNormalizing: false,

DisableResponseHeaderNamesNormalizing: false,
}
options.Apply(opts)
return options
Expand Down
2 changes: 2 additions & 0 deletions pkg/common/config/option_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ func TestDefaultOptions(t *testing.T) {
assert.DeepEqual(t, []interface{}{}, options.Tracers)
assert.DeepEqual(t, new(interface{}), options.TraceLevel)
assert.DeepEqual(t, registry.NoopRegistry, options.Registry)
assert.DeepEqual(t, false, options.DisableRequestHeaderNamesNormalizing)
assert.DeepEqual(t, false, options.DisableResponseHeaderNamesNormalizing)
}

// TestApplyCustomOptions test apply options with custom values after init
Expand Down
34 changes: 14 additions & 20 deletions pkg/protocol/http1/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,21 +54,20 @@ var (
)

type Option struct {
StreamRequestBody bool
GetOnly bool
DisablePreParseMultipartForm bool
DisableKeepalive bool
NoDefaultServerHeader bool
MaxRequestBodySize int
IdleTimeout time.Duration
ReadTimeout time.Duration
ServerName []byte
TLS *tls.Config
HTMLRender render.HTMLRender
EnableTrace bool
ContinueHandler func(header *protocol.RequestHeader) bool
HijackConnHandle func(c network.Conn, h app.HijackHandler)
DisableHeaderNamesNormalizing bool
StreamRequestBody bool
GetOnly bool
DisablePreParseMultipartForm bool
DisableKeepalive bool
NoDefaultServerHeader bool
MaxRequestBodySize int
IdleTimeout time.Duration
ReadTimeout time.Duration
ServerName []byte
TLS *tls.Config
HTMLRender render.HTMLRender
EnableTrace bool
ContinueHandler func(header *protocol.RequestHeader) bool
HijackConnHandle func(c network.Conn, h app.HijackHandler)
}

type Server struct {
Expand Down Expand Up @@ -181,11 +180,6 @@ func (s Server) Serve(c context.Context, conn network.Conn) (err error) {
})
}

if s.DisableHeaderNamesNormalizing {
ctx.Request.Header.DisableNormalizing()
ctx.Response.Header.DisableNormalizing()
}

// Read Headers
if err = req.ReadHeader(&ctx.Request.Header, zr); err == nil {
if s.EnableTrace {
Expand Down
37 changes: 22 additions & 15 deletions pkg/route/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,14 @@ func (engine *Engine) allocateContext() *app.RequestContext {
ctx.Response.SetMaxKeepBodySize(engine.options.MaxKeepBodySize)
ctx.SetClientIPFunc(engine.clientIPFunc)
ctx.SetFormValueFunc(engine.formValueFunc)
if engine.options.DisableRequestHeaderNamesNormalizing {
ctx.Request.Header.DisableNormalizing()
ctx.Request.Header.Trailer().DisableNormalizing()
}
if engine.options.DisableResponseHeaderNamesNormalizing {
ctx.Response.Header.DisableNormalizing()
ctx.Response.Header.Trailer().DisableNormalizing()
}
return ctx
}

Expand Down Expand Up @@ -990,21 +998,20 @@ func iterate(method string, routes RoutesInfo, root *node) RoutesInfo {
// for built-in http1 impl only.
func newHttp1OptionFromEngine(engine *Engine) *http1.Option {
opt := &http1.Option{
StreamRequestBody: engine.options.StreamRequestBody,
GetOnly: engine.options.GetOnly,
DisablePreParseMultipartForm: engine.options.DisablePreParseMultipartForm,
DisableKeepalive: engine.options.DisableKeepalive,
NoDefaultServerHeader: engine.options.NoDefaultServerHeader,
MaxRequestBodySize: engine.options.MaxRequestBodySize,
IdleTimeout: engine.options.IdleTimeout,
ReadTimeout: engine.options.ReadTimeout,
ServerName: engine.GetServerName(),
ContinueHandler: engine.ContinueHandler,
TLS: engine.options.TLS,
HTMLRender: engine.htmlRender,
EnableTrace: engine.IsTraceEnable(),
HijackConnHandle: engine.HijackConnHandle,
DisableHeaderNamesNormalizing: engine.options.DisableHeaderNamesNormalizing,
StreamRequestBody: engine.options.StreamRequestBody,
GetOnly: engine.options.GetOnly,
DisablePreParseMultipartForm: engine.options.DisablePreParseMultipartForm,
DisableKeepalive: engine.options.DisableKeepalive,
NoDefaultServerHeader: engine.options.NoDefaultServerHeader,
MaxRequestBodySize: engine.options.MaxRequestBodySize,
IdleTimeout: engine.options.IdleTimeout,
ReadTimeout: engine.options.ReadTimeout,
ServerName: engine.GetServerName(),
ContinueHandler: engine.ContinueHandler,
TLS: engine.options.TLS,
HTMLRender: engine.htmlRender,
EnableTrace: engine.IsTraceEnable(),
HijackConnHandle: engine.HijackConnHandle,
}
// Idle timeout of standard network must not be zero. Set it to -1 seconds if it is zero.
// Due to the different triggering ways of the network library, see the actual use of this value for the detailed reasons.
Expand Down

0 comments on commit 935bb29

Please sign in to comment.