Skip to content

Commit

Permalink
Merge branch 'develop' into test_pkg_common_utils
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardo3202 authored Jun 11, 2024
2 parents d97ec8a + 7437071 commit d5dcbc8
Show file tree
Hide file tree
Showing 21 changed files with 202 additions and 53 deletions.
5 changes: 5 additions & 0 deletions cmd/hz/app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ func Init() *cli.App {
customLayoutData := cli.StringFlag{Name: "customize_layout_data_path", Usage: "Specify the path for layout template render data.", Destination: &globalArgs.CustomizeLayoutData}
customPackage := cli.StringFlag{Name: "customize_package", Usage: "Specify the path for package template.", Destination: &globalArgs.CustomizePackage}
handlerByMethod := cli.BoolFlag{Name: "handler_by_method", Usage: "Generate a separate handler file for each method.", Destination: &globalArgs.HandlerByMethod}
trimGoPackage := cli.StringFlag{Name: "trim_gopackage", Aliases: []string{"trim_pkg"}, Usage: "Trim the prefix of go_package for protobuf.", Destination: &globalArgs.TrimGoPackage}

// app
app := cli.NewApp()
Expand Down Expand Up @@ -225,6 +226,7 @@ func Init() *cli.App {
&thriftOptionsFlag,
&protoOptionsFlag,
&optPkgFlag,
&trimGoPackage,
&noRecurseFlag,
&forceNewFlag,
&enableExtendsFlag,
Expand Down Expand Up @@ -261,6 +263,7 @@ func Init() *cli.App {
&thriftOptionsFlag,
&protoOptionsFlag,
&optPkgFlag,
&trimGoPackage,
&noRecurseFlag,
&enableExtendsFlag,
&sortRouterFlag,
Expand Down Expand Up @@ -291,6 +294,7 @@ func Init() *cli.App {
&thriftOptionsFlag,
&protoOptionsFlag,
&noRecurseFlag,
&trimGoPackage,

&jsonEnumStrFlag,
&unsetOmitemptyFlag,
Expand Down Expand Up @@ -318,6 +322,7 @@ func Init() *cli.App {
&protoOptionsFlag,
&noRecurseFlag,
&enableExtendsFlag,
&trimGoPackage,

&jsonEnumStrFlag,
&queryEnumIntFlag,
Expand Down
13 changes: 7 additions & 6 deletions cmd/hz/config/argument.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,13 @@ type Argument struct {
BaseDomain string // request domain
ForceClientDir string // client dir (not use namespace as a subpath)

IdlType string // idl type
IdlPaths []string // master idl path
RawOptPkg []string // user-specified package import path
OptPkgMap map[string]string
Includes []string
PkgPrefix string
IdlType string // idl type
IdlPaths []string // master idl path
RawOptPkg []string // user-specified package import path
OptPkgMap map[string]string
Includes []string
PkgPrefix string
TrimGoPackage string // trim go_package for protobuf, avoid to generate multiple directory

Gopath string // $GOPATH
Gosrc string // $GOPATH/src
Expand Down
6 changes: 3 additions & 3 deletions cmd/hz/protobuf/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ func parseAnnotationToClient(clientMethod *generator.ClientMethod, gen *protogen
if proto.HasExtension(f.Desc.Options(), api.E_Path) {
hasAnnotation = true
pathAnnos := proto.GetExtension(f.Desc.Options(), api.E_Path)
val := checkSnakeName(pathAnnos.(string))
val := pathAnnos.(string)
if isStringFieldType {
clientMethod.PathParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", val, f.GoName)
} else {
Expand All @@ -323,7 +323,7 @@ func parseAnnotationToClient(clientMethod *generator.ClientMethod, gen *protogen
if proto.HasExtension(f.Desc.Options(), api.E_Header) {
hasAnnotation = true
headerAnnos := proto.GetExtension(f.Desc.Options(), api.E_Header)
val := checkSnakeName(headerAnnos.(string))
val := headerAnnos.(string)
if isStringFieldType {
clientMethod.HeaderParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", val, f.GoName)
} else {
Expand All @@ -350,7 +350,7 @@ func parseAnnotationToClient(clientMethod *generator.ClientMethod, gen *protogen
if fileAnnos := getCompatibleAnnotation(f.Desc.Options(), api.E_FileName, api.E_FileNameCompatible); fileAnnos != nil {
hasAnnotation = true
hasFormAnnotation = true
val := checkSnakeName(fileAnnos.(string))
val := fileAnnos.(string)
clientMethod.FormFileCode += fmt.Sprintf("%q: req.Get%s(),\n", val, f.GoName)
}
if !hasAnnotation && strings.EqualFold(clientMethod.HTTPMethod, "get") {
Expand Down
8 changes: 6 additions & 2 deletions cmd/hz/protobuf/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ func (plugin *Plugin) Response(resp *pluginpb.CodeGeneratorResponse) error {
}

func (plugin *Plugin) Handle(req *pluginpb.CodeGeneratorRequest, args *config.Argument) error {
plugin.fixGoPackage(req, plugin.PkgMap)
plugin.fixGoPackage(req, plugin.PkgMap, args.TrimGoPackage)

// new plugin
opts := protogen.Options{}
Expand Down Expand Up @@ -291,12 +291,16 @@ func (plugin *Plugin) Handle(req *pluginpb.CodeGeneratorRequest, args *config.Ar
}

// fixGoPackage will update go_package to store all the model files in ${model_dir}
func (plugin *Plugin) fixGoPackage(req *pluginpb.CodeGeneratorRequest, pkgMap map[string]string) {
func (plugin *Plugin) fixGoPackage(req *pluginpb.CodeGeneratorRequest, pkgMap map[string]string, trimGoPackage string) {
gopkg := plugin.Package
for _, f := range req.ProtoFile {
if strings.HasPrefix(f.GetPackage(), "google.protobuf") {
continue
}
if len(trimGoPackage) != 0 && strings.HasPrefix(f.GetOptions().GetGoPackage(), trimGoPackage) {
*f.Options.GoPackage = strings.TrimPrefix(*f.Options.GoPackage, trimGoPackage)
}

opt := getGoPackage(f, pkgMap)
if !strings.Contains(opt, gopkg) {
if strings.HasPrefix(opt, "/") {
Expand Down
6 changes: 3 additions & 3 deletions cmd/hz/thrift/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ func parseAnnotationToClient(clientMethod *generator.ClientMethod, p *parser.Typ

if anno := getAnnotation(field.Annotations, AnnotationPath); len(anno) > 0 {
hasAnnotation = true
path := checkSnakeName(anno[0])
path := anno[0]
if isStringFieldType {
clientMethod.PathParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", path, field.GoName().String())
} else {
Expand All @@ -291,7 +291,7 @@ func parseAnnotationToClient(clientMethod *generator.ClientMethod, p *parser.Typ

if anno := getAnnotation(field.Annotations, AnnotationHeader); len(anno) > 0 {
hasAnnotation = true
header := checkSnakeName(anno[0])
header := anno[0]
if isStringFieldType {
clientMethod.HeaderParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", header, field.GoName().String())
} else {
Expand All @@ -317,7 +317,7 @@ func parseAnnotationToClient(clientMethod *generator.ClientMethod, p *parser.Typ

if anno := getAnnotation(field.Annotations, AnnotationFileName); len(anno) > 0 {
hasAnnotation = true
fileName := checkSnakeName(anno[0])
fileName := anno[0]
hasFormAnnotation = true
clientMethod.FormFileCode += fmt.Sprintf("%q: req.Get%s(),\n", fileName, field.GoName().String())
}
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ require (
github.com/bytedance/gopkg v0.0.0-20220413063733-65bf48ffb3a7
github.com/bytedance/mockey v1.2.1
github.com/bytedance/sonic v1.8.1
github.com/cloudwego/netpoll v0.5.0
github.com/cloudwego/netpoll v0.6.0
github.com/fsnotify/fsnotify v1.5.4
github.com/tidwall/gjson v1.14.4
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ github.com/bytedance/sonic v1.8.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZX
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
github.com/cloudwego/netpoll v0.5.0 h1:oRrOp58cPCvK2QbMozZNDESvrxQaEHW2dCimmwH1lcU=
github.com/cloudwego/netpoll v0.5.0/go.mod h1:xVefXptcyheopwNDZjDPcfU6kIjZXZ4nY550k1yH9eQ=
github.com/cloudwego/netpoll v0.6.0 h1:JRMkrA1o8k/4quxzg6Q1XM+zIhwZsyoWlq6ef+ht31U=
github.com/cloudwego/netpoll v0.6.0/go.mod h1:xVefXptcyheopwNDZjDPcfU6kIjZXZ4nY550k1yH9eQ=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
Expand Down
10 changes: 8 additions & 2 deletions pkg/app/server/binding/binder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -466,10 +466,14 @@ func TestBind_RequiredBind(t *testing.T) {
A int `query:"a,required"`
}
req := newMockRequest().
SetRequestURI("http://foobar.com")
err := DefaultBinder().Bind(req.Req, &s, nil)
assert.DeepEqual(t, "'a' field is a 'required' parameter, but the request does not have this parameter", err.Error())

req = newMockRequest().
SetRequestURI("http://foobar.com").
SetHeader("A", "1")

err := DefaultBinder().Bind(req.Req, &s, nil)
err = DefaultBinder().Bind(req.Req, &s, nil)
if err == nil {
t.Fatal("expected error")
}
Expand Down Expand Up @@ -904,6 +908,7 @@ func TestBind_JSONRequiredField(t *testing.T) {
if err == nil {
t.Errorf("expected an error, but get nil")
}
assert.DeepEqual(t, "'c' field is a 'required' parameter, but the request body does not have this parameter 'n.n2.c'", err.Error())
assert.DeepEqual(t, 1, result.N.A)
assert.DeepEqual(t, 2, result.N.B)
assert.DeepEqual(t, 0, result.N.N2.C)
Expand Down Expand Up @@ -1492,6 +1497,7 @@ func Test_ValidatorErrorFactory(t *testing.T) {
if err == nil {
t.Fatalf("unexpected nil, expected an error")
}
assert.DeepEqual(t, "'a' field is a 'required' parameter, but the request does not have this parameter", err.Error())

type TestValidate struct {
B int `query:"b" vd:"$>100"`
Expand Down
4 changes: 2 additions & 2 deletions pkg/app/server/binding/internal/decoder/base_type_decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func (d *baseTypeFieldTextDecoder) Decode(req *protocol.Request, params param.Pa
if found {
err = nil
} else {
err = fmt.Errorf("'%s' field is a 'required' parameter, but the request body does not have this parameter '%s'", d.fieldName, tagInfo.JSONName)
err = fmt.Errorf("'%s' field is a 'required' parameter, but the request body does not have this parameter '%s'", tagInfo.Value, tagInfo.JSONName)
}
if len(tagInfo.Default) != 0 && keyExist(req, tagInfo) {
defaultValue = ""
Expand All @@ -90,7 +90,7 @@ func (d *baseTypeFieldTextDecoder) Decode(req *protocol.Request, params param.Pa
break
}
if tagInfo.Required {
err = fmt.Errorf("'%s' field is a 'required' parameter, but the request does not have this parameter", d.fieldName)
err = fmt.Errorf("'%s' field is a 'required' parameter, but the request does not have this parameter", tagInfo.Value)
}
}
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions pkg/app/server/binding/internal/decoder/map_type_decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func (d *mapTypeFieldTextDecoder) Decode(req *protocol.Request, params param.Par
if found {
err = nil
} else {
err = fmt.Errorf("'%s' field is a 'required' parameter, but the request does not have this parameter", d.fieldName)
err = fmt.Errorf("'%s' field is a 'required' parameter, but the request does not have this parameter", tagInfo.Value)
}
if len(tagInfo.Default) != 0 && keyExist(req, tagInfo) {
defaultValue = ""
Expand All @@ -82,7 +82,7 @@ func (d *mapTypeFieldTextDecoder) Decode(req *protocol.Request, params param.Par
break
}
if tagInfo.Required {
err = fmt.Errorf("'%s' field is a 'required' parameter, but the request does not have this parameter", d.fieldName)
err = fmt.Errorf("'%s' field is a 'required' parameter, but the request does not have this parameter", tagInfo.Value)
}
}
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions pkg/app/server/binding/internal/decoder/slice_type_decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func (d *sliceTypeFieldTextDecoder) Decode(req *protocol.Request, params param.P
if found {
err = nil
} else {
err = fmt.Errorf("'%s' field is a 'required' parameter, but the request does not have this parameter", d.fieldName)
err = fmt.Errorf("'%s' field is a 'required' parameter, but the request does not have this parameter", tagInfo.Value)
}
if len(tagInfo.Default) != 0 && keyExist(req, tagInfo) { //
defaultValue = ""
Expand All @@ -88,7 +88,7 @@ func (d *sliceTypeFieldTextDecoder) Decode(req *protocol.Request, params param.P
break
}
if tagInfo.Required {
err = fmt.Errorf("'%s' field is a 'required' parameter, but the request does not have this parameter", d.fieldName)
err = fmt.Errorf("'%s' field is a 'required' parameter, but the request does not have this parameter", tagInfo.Value)
}
}
if err != nil {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func (d *structTypeFieldTextDecoder) Decode(req *protocol.Request, params param.
if found {
err = nil
} else {
err = fmt.Errorf("'%s' field is a 'required' parameter, but the request does not have this parameter", d.fieldName)
err = fmt.Errorf("'%s' field is a 'required' parameter, but the request does not have this parameter", tagInfo.Value)
}
if len(tagInfo.Default) != 0 && keyExist(req, tagInfo) {
defaultValue = ""
Expand All @@ -59,7 +59,7 @@ func (d *structTypeFieldTextDecoder) Decode(req *protocol.Request, params param.
break
}
if tagInfo.Required {
err = fmt.Errorf("'%s' field is a 'required' parameter, but the request does not have this parameter", d.fieldName)
err = fmt.Errorf("'%s' field is a 'required' parameter, but the request does not have this parameter", tagInfo.Value)
}
}
if err != nil {
Expand Down
6 changes: 3 additions & 3 deletions pkg/app/server/binding/tagexpr_bind_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func TestGetBody(t *testing.T) {
if err == nil {
t.Fatalf("expected an error, but get nil")
}
assert.DeepEqual(t, err.Error(), "'E' field is a 'required' parameter, but the request body does not have this parameter 'X.e'")
assert.DeepEqual(t, err.Error(), "'e' field is a 'required' parameter, but the request body does not have this parameter 'X.e'")
}

func TestQueryNum(t *testing.T) {
Expand Down Expand Up @@ -431,7 +431,7 @@ func TestJSON(t *testing.T) {
if err == nil {
t.Error("expected an error, but get nil")
}
assert.DeepEqual(t, err.Error(), "'Y' field is a 'required' parameter, but the request body does not have this parameter 'y'")
assert.DeepEqual(t, err.Error(), "'y' field is a 'required' parameter, but the request body does not have this parameter 'y'")
assert.DeepEqual(t, []string{"a1", "a2"}, (**recv.X).A)
assert.DeepEqual(t, int32(21), (**recv.X).B)
assert.DeepEqual(t, &[]uint16{31, 32}, (**recv.X).C)
Expand Down Expand Up @@ -753,7 +753,7 @@ func TestOption(t *testing.T) {
req = newRequest("", header, nil, bodyReader)
recv = new(Recv)
err = DefaultBinder().Bind(req.Req, recv, nil)
assert.DeepEqual(t, err.Error(), "'C' field is a 'required' parameter, but the request body does not have this parameter 'X.c'")
assert.DeepEqual(t, err.Error(), "'c' field is a 'required' parameter, but the request body does not have this parameter 'X.c'")
assert.DeepEqual(t, 0, recv.X.C)
assert.DeepEqual(t, 0, recv.X.D)
assert.DeepEqual(t, "y1", recv.Y)
Expand Down
55 changes: 55 additions & 0 deletions pkg/app/server/hertz_unix_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
c "github.com/cloudwego/hertz/pkg/app/client"
"github.com/cloudwego/hertz/pkg/common/test/assert"
"github.com/cloudwego/hertz/pkg/common/utils"
"github.com/cloudwego/hertz/pkg/network"
"github.com/cloudwego/hertz/pkg/network/standard"
"github.com/cloudwego/hertz/pkg/protocol/consts"
"golang.org/x/sys/unix"
Expand Down Expand Up @@ -134,3 +135,57 @@ func TestHertz_Spin(t *testing.T) {

<-ch2
}

func TestWithSenseClientDisconnection(t *testing.T) {
var closeFlag int32
h := New(WithHostPorts("127.0.0.1:6631"), WithSenseClientDisconnection(true))
h.GET("/ping", func(c context.Context, ctx *app.RequestContext) {
assert.DeepEqual(t, "aa", string(ctx.Host()))
ch := make(chan struct{})
select {
case <-c.Done():
atomic.StoreInt32(&closeFlag, 1)
assert.DeepEqual(t, context.Canceled, c.Err())
case <-ch:
}
})
go h.Spin()
time.Sleep(time.Second)
con, err := net.Dial("tcp", "127.0.0.1:6631")
assert.Nil(t, err)
_, err = con.Write([]byte("GET /ping HTTP/1.1\r\nHost: aa\r\n\r\n"))
assert.Nil(t, err)
time.Sleep(time.Second)
assert.DeepEqual(t, atomic.LoadInt32(&closeFlag), int32(0))
assert.Nil(t, con.Close())
time.Sleep(time.Second)
assert.DeepEqual(t, atomic.LoadInt32(&closeFlag), int32(1))
}

func TestWithSenseClientDisconnectionAndWithOnConnect(t *testing.T) {
var closeFlag int32
h := New(WithHostPorts("127.0.0.1:6632"), WithSenseClientDisconnection(true), WithOnConnect(func(ctx context.Context, conn network.Conn) context.Context {
return ctx
}))
h.GET("/ping", func(c context.Context, ctx *app.RequestContext) {
assert.DeepEqual(t, "aa", string(ctx.Host()))
ch := make(chan struct{})
select {
case <-c.Done():
atomic.StoreInt32(&closeFlag, 1)
assert.DeepEqual(t, context.Canceled, c.Err())
case <-ch:
}
})
go h.Spin()
time.Sleep(time.Second)
con, err := net.Dial("tcp", "127.0.0.1:6632")
assert.Nil(t, err)
_, err = con.Write([]byte("GET /ping HTTP/1.1\r\nHost: aa\r\n\r\n"))
assert.Nil(t, err)
time.Sleep(time.Second)
assert.DeepEqual(t, atomic.LoadInt32(&closeFlag), int32(0))
assert.Nil(t, con.Close())
time.Sleep(time.Second)
assert.DeepEqual(t, atomic.LoadInt32(&closeFlag), int32(1))
}
16 changes: 16 additions & 0 deletions pkg/app/server/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -394,3 +394,19 @@ func WithDisableDefaultContentType(disable bool) config.Option {
o.NoDefaultContentType = disable
}}
}

// WithSenseClientDisconnection sets the ability to sense client disconnections.
// If we don't set it, it will default to false.
// There are two issues to note when using this option:
// 1. Warning: It only applies to netpoll.
// 2. After opening, the context.Context in the request will be cancelled.
//
// Example:
// server.Default(
// server.WithSenseClientDisconnection(true),
// )
func WithSenseClientDisconnection(b bool) config.Option {
return config.Option{F: func(o *config.Options) {
o.SenseClientDisconnection = b
}}
}
Loading

0 comments on commit d5dcbc8

Please sign in to comment.