Skip to content

Commit

Permalink
feat: remove normalize for header bind
Browse files Browse the repository at this point in the history
  • Loading branch information
FGYFFFF committed Sep 22, 2023
1 parent a6e4159 commit 256f9bb
Show file tree
Hide file tree
Showing 8 changed files with 74 additions and 36 deletions.
74 changes: 74 additions & 0 deletions pkg/app/server/binding/binder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1362,6 +1362,80 @@ func TestBind_InterfaceType(t *testing.T) {
}
}

func Test_BindHeaderNormalize(t *testing.T) {
type Req struct {
Header string `header:"h"`
}

req := newMockRequest().
SetRequestURI("http://foobar.com").
SetHeaders("h", "header")
var result Req

err := DefaultBinder().Bind(req.Req, &result, nil)
if err != nil {
t.Error(err)
}
assert.DeepEqual(t, "header", result.Header)
req = newMockRequest().
SetRequestURI("http://foobar.com").
SetHeaders("H", "header")
err = DefaultBinder().Bind(req.Req, &result, nil)
if err != nil {
t.Error(err)
}
assert.DeepEqual(t, "header", result.Header)

type Req2 struct {
Header string `header:"H"`
}

req2 := newMockRequest().
SetRequestURI("http://foobar.com").
SetHeaders("h", "header")
var result2 Req2

err2 := DefaultBinder().Bind(req2.Req, &result2, nil)
if err != nil {
t.Error(err2)
}
assert.DeepEqual(t, "header", result2.Header)
req2 = newMockRequest().
SetRequestURI("http://foobar.com").
SetHeaders("H", "header")
err2 = DefaultBinder().Bind(req2.Req, &result2, nil)
if err2 != nil {
t.Error(err2)
}
assert.DeepEqual(t, "header", result2.Header)

type Req3 struct {
Header string `header:"h"`
}

// without normalize, the header key & tag key need to be consistent
req3 := newMockRequest().
SetRequestURI("http://foobar.com")
req3.Req.Header.DisableNormalizing()
req3.SetHeaders("h", "header")
var result3 Req3
err3 := DefaultBinder().Bind(req3.Req, &result3, nil)
if err3 != nil {
t.Error(err3)
}
assert.DeepEqual(t, "header", result3.Header)
req3 = newMockRequest().
SetRequestURI("http://foobar.com")
req3.Req.Header.DisableNormalizing()
req3.SetHeaders("H", "header")
result3 = Req3{}
err3 = DefaultBinder().Bind(req3.Req, &result3, nil)
if err3 != nil {
t.Error(err3)
}
assert.DeepEqual(t, "", result3.Header)
}

func Benchmark_Binding(b *testing.B) {
type Req struct {
Version string `path:"v"`
Expand Down
4 changes: 0 additions & 4 deletions pkg/app/server/binding/internal/decoder/base_type_decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ import (
"fmt"
"reflect"

"github.com/cloudwego/hertz/pkg/common/utils"
"github.com/cloudwego/hertz/pkg/protocol"
"github.com/cloudwego/hertz/pkg/route/param"
)
Expand Down Expand Up @@ -81,9 +80,6 @@ func (d *baseTypeFieldTextDecoder) Decode(req *protocol.Request, params param.Pa
}
continue
}
if tagInfo.Key == headerTag {
tagInfo.Value = utils.GetNormalizeHeaderKey(tagInfo.Value, req.Header.IsDisableNormalizing())
}
text, exist = tagInfo.Getter(req, params, tagInfo.Value)
defaultValue = tagInfo.Default
if exist {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ package decoder
import (
"reflect"

"github.com/cloudwego/hertz/pkg/common/utils"
"github.com/cloudwego/hertz/pkg/protocol"
"github.com/cloudwego/hertz/pkg/route/param"
)
Expand All @@ -64,9 +63,6 @@ func (d *customizedFieldTextDecoder) Decode(req *protocol.Request, params param.
defaultValue = tagInfo.Default
continue
}
if tagInfo.Key == headerTag {
tagInfo.Value = utils.GetNormalizeHeaderKey(tagInfo.Value, req.Header.IsDisableNormalizing())
}
text, exist = tagInfo.Getter(req, params, tagInfo.Value)
defaultValue = tagInfo.Default
if exist {
Expand Down
4 changes: 0 additions & 4 deletions pkg/app/server/binding/internal/decoder/map_type_decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ import (

"github.com/cloudwego/hertz/internal/bytesconv"
hJson "github.com/cloudwego/hertz/pkg/common/json"
"github.com/cloudwego/hertz/pkg/common/utils"
"github.com/cloudwego/hertz/pkg/protocol"
"github.com/cloudwego/hertz/pkg/route/param"
)
Expand All @@ -73,9 +72,6 @@ func (d *mapTypeFieldTextDecoder) Decode(req *protocol.Request, params param.Par
}
continue
}
if tagInfo.Key == headerTag {
tagInfo.Value = utils.GetNormalizeHeaderKey(tagInfo.Value, req.Header.IsDisableNormalizing())
}
text, exist = tagInfo.Getter(req, params, tagInfo.Value)
defaultValue = tagInfo.Default
if exist {
Expand Down
4 changes: 0 additions & 4 deletions pkg/app/server/binding/internal/decoder/slice_type_decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ import (

"github.com/cloudwego/hertz/internal/bytesconv"
hJson "github.com/cloudwego/hertz/pkg/common/json"
"github.com/cloudwego/hertz/pkg/common/utils"
"github.com/cloudwego/hertz/pkg/protocol"
"github.com/cloudwego/hertz/pkg/route/param"
)
Expand Down Expand Up @@ -75,9 +74,6 @@ func (d *sliceTypeFieldTextDecoder) Decode(req *protocol.Request, params param.P
}
continue
}
if tagInfo.Key == headerTag {
tagInfo.Value = utils.GetNormalizeHeaderKey(tagInfo.Value, req.Header.IsDisableNormalizing())
}
if tagInfo.Key == rawBodyTag {
bindRawBody = true
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (
"github.com/cloudwego/hertz/internal/bytesconv"
"github.com/cloudwego/hertz/pkg/common/hlog"
hjson "github.com/cloudwego/hertz/pkg/common/json"
"github.com/cloudwego/hertz/pkg/common/utils"
"github.com/cloudwego/hertz/pkg/protocol"
"github.com/cloudwego/hertz/pkg/route/param"
)
Expand All @@ -50,9 +49,6 @@ func (d *structTypeFieldTextDecoder) Decode(req *protocol.Request, params param.
}
continue
}
if tagInfo.Key == headerTag {
tagInfo.Value = utils.GetNormalizeHeaderKey(tagInfo.Value, req.Header.IsDisableNormalizing())
}
text, exist = tagInfo.Getter(req, params, tagInfo.Value)
defaultValue = tagInfo.Default
if exist {
Expand Down
6 changes: 0 additions & 6 deletions pkg/common/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,6 @@ func CaseInsensitiveCompare(a, b []byte) bool {
return true
}

func GetNormalizeHeaderKey(key string, disableNormalizing bool) string {
keyBytes := []byte(key)
NormalizeHeaderKey(keyBytes, disableNormalizing)
return string(keyBytes)
}

func NormalizeHeaderKey(b []byte, disableNormalizing bool) {
if disableNormalizing {
return
Expand Down
10 changes: 0 additions & 10 deletions pkg/common/utils/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,13 +142,3 @@ func TestFilterContentType(t *testing.T) {
contentType = FilterContentType(contentType)
assert.DeepEqual(t, "text/plain", contentType)
}

func TestGetNormalizeHeaderKey(t *testing.T) {
key := "content-type"
key = GetNormalizeHeaderKey(key, false)
assert.DeepEqual(t, "Content-Type", key)

key = "content-type"
key = GetNormalizeHeaderKey(key, true)
assert.DeepEqual(t, "content-type", key)
}

0 comments on commit 256f9bb

Please sign in to comment.