diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 1099f2a06..9f387556f 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -21,16 +21,17 @@ The description of the title will be attached in Release Notes, so please describe it from user-oriented, what this PR does / why we need it. Please check your PR title with the below requirements: --> -- [ ] This PR title match the format: `(optional scope): `. +- [ ] This PR title match the format: \(optional scope): \ - [ ] The description of this PR title is user-oriented and clear enough for others to understand. -- [ ] Attach the PR updating the user documentation if the current PR requires user awareness at the usage level. [User docs repo](https://github.com/cloudwego/cloudwego.github.io). +- [ ] Attach the PR updating the user documentation if the current PR requires user awareness at the usage level. [User docs repo](https://github.com/cloudwego/cloudwego.github.io) + #### (Optional) Translate the PR title into Chinese. -#### (Optional) More detail description for this PR(en: English/zh: Chinese). +#### (Optional) More detailed description for this PR(en: English/zh: Chinese). en: zh(optional): @@ -44,4 +45,4 @@ Eg: `Fixes #`, or `Fixes (paste link of issue)`. #### (Optional) The PR that updates user documentation: +--> \ No newline at end of file diff --git a/cmd/hz/generator/handler.go b/cmd/hz/generator/handler.go index 4d797d26e..2a949c5d6 100644 --- a/cmd/hz/generator/handler.go +++ b/cmd/hz/generator/handler.go @@ -44,6 +44,7 @@ type HttpMethod struct { RefPackage string // handler import dir RefPackageAlias string // handler import alias ModelPackage map[string]string + GenHandler bool // Whether to generate one handler, when an idl interface corresponds to multiple http method // Annotations map[string]string Models map[string]*model.Model } @@ -78,8 +79,10 @@ func (pkgGen *HttpPackageGenerator) genHandler(pkg *HttpPackage, handlerDir, han return fmt.Errorf("generate handler %s failed, err: %v", handler.FilePath, err.Error()) } - if err := pkgGen.updateHandler(handler, handlerTplName, handler.FilePath, false); err != nil { - return fmt.Errorf("generate handler %s failed, err: %v", handler.FilePath, err.Error()) + if m.GenHandler { + if err := pkgGen.updateHandler(handler, handlerTplName, handler.FilePath, false); err != nil { + return fmt.Errorf("generate handler %s failed, err: %v", handler.FilePath, err.Error()) + } } } } else { // generate handler service @@ -105,6 +108,15 @@ func (pkgGen *HttpPackageGenerator) genHandler(pkg *HttpPackage, handlerDir, han return fmt.Errorf("generate handler %s failed, err: %v", handler.FilePath, err.Error()) } + // Avoid generating duplicate handlers when IDL interface corresponds to multiple http methods + methods := handler.Methods + handler.Methods = []*HttpMethod{} + for _, m := range methods { + if m.GenHandler { + handler.Methods = append(handler.Methods, m) + } + } + if err := pkgGen.updateHandler(handler, handlerTplName, handler.FilePath, false); err != nil { return fmt.Errorf("generate handler %s failed, err: %v", handler.FilePath, err.Error()) } diff --git a/cmd/hz/meta/const.go b/cmd/hz/meta/const.go index 8e7490dfe..abda7c71c 100644 --- a/cmd/hz/meta/const.go +++ b/cmd/hz/meta/const.go @@ -19,7 +19,7 @@ package meta import "runtime" // Version hz version -const Version = "v0.6.6" +const Version = "v0.6.7" const DefaultServiceName = "hertz_service" diff --git a/cmd/hz/protobuf/api/api.pb.go b/cmd/hz/protobuf/api/api.pb.go index 9707626a4..4c1d8c66c 100644 --- a/cmd/hz/protobuf/api/api.pb.go +++ b/cmd/hz/protobuf/api/api.pb.go @@ -1,16 +1,17 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.28.1 +// protoc-gen-go v1.30.0 // protoc v3.21.12 // source: api.proto package api import ( + reflect "reflect" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" descriptorpb "google.golang.org/protobuf/types/descriptorpb" - reflect "reflect" ) const ( @@ -317,6 +318,22 @@ var file_api_proto_extTypes = []protoimpl.ExtensionInfo{ Tag: "bytes,50731,opt,name=base_domain_compatible", Filename: "api.proto", }, + { + ExtendedType: (*descriptorpb.ServiceOptions)(nil), + ExtensionType: (*string)(nil), + Field: 50732, + Name: "api.service_path", + Tag: "bytes,50732,opt,name=service_path", + Filename: "api.proto", + }, + { + ExtendedType: (*descriptorpb.MessageOptions)(nil), + ExtensionType: (*string)(nil), + Field: 50830, + Name: "api.reserve", + Tag: "bytes,50830,opt,name=reserve", + Filename: "api.proto", + }, } // Extension fields to descriptorpb.FieldOptions. @@ -413,6 +430,14 @@ var ( // // optional string base_domain_compatible = 50731; E_BaseDomainCompatible = &file_api_proto_extTypes[36] + // optional string service_path = 50732; + E_ServicePath = &file_api_proto_extTypes[37] +) + +// Extension fields to descriptorpb.MessageOptions. +var ( + // optional string reserve = 50830; + E_Reserve = &file_api_proto_extTypes[38] ) var File_api_proto protoreflect.FileDescriptor @@ -563,7 +588,15 @@ var file_api_proto_rawDesc = []byte{ 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xab, 0x8c, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x14, 0x62, 0x61, 0x73, 0x65, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x43, 0x6f, 0x6d, 0x70, 0x61, - 0x74, 0x69, 0x62, 0x6c, 0x65, 0x42, 0x06, 0x5a, 0x04, 0x2f, 0x61, 0x70, 0x69, + 0x74, 0x69, 0x62, 0x6c, 0x65, 0x3a, 0x44, 0x0a, 0x0c, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, + 0x5f, 0x70, 0x61, 0x74, 0x68, 0x12, 0x1f, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x4f, + 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xac, 0x8c, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, + 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x50, 0x61, 0x74, 0x68, 0x3a, 0x3b, 0x0a, 0x07, 0x72, + 0x65, 0x73, 0x65, 0x72, 0x76, 0x65, 0x12, 0x1f, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, + 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x8e, 0x8d, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x07, 0x72, 0x65, 0x73, 0x65, 0x72, 0x76, 0x65, 0x42, 0x06, 0x5a, 0x04, 0x2f, 0x61, 0x70, 0x69, } var file_api_proto_goTypes = []interface{}{ @@ -571,6 +604,7 @@ var file_api_proto_goTypes = []interface{}{ (*descriptorpb.MethodOptions)(nil), // 1: google.protobuf.MethodOptions (*descriptorpb.EnumValueOptions)(nil), // 2: google.protobuf.EnumValueOptions (*descriptorpb.ServiceOptions)(nil), // 3: google.protobuf.ServiceOptions + (*descriptorpb.MessageOptions)(nil), // 4: google.protobuf.MessageOptions } var file_api_proto_depIdxs = []int32{ 0, // 0: api.raw_body:extendee -> google.protobuf.FieldOptions @@ -610,10 +644,12 @@ var file_api_proto_depIdxs = []int32{ 2, // 34: api.http_code:extendee -> google.protobuf.EnumValueOptions 3, // 35: api.base_domain:extendee -> google.protobuf.ServiceOptions 3, // 36: api.base_domain_compatible:extendee -> google.protobuf.ServiceOptions - 37, // [37:37] is the sub-list for method output_type - 37, // [37:37] is the sub-list for method input_type - 37, // [37:37] is the sub-list for extension type_name - 0, // [0:37] is the sub-list for extension extendee + 3, // 37: api.service_path:extendee -> google.protobuf.ServiceOptions + 4, // 38: api.reserve:extendee -> google.protobuf.MessageOptions + 39, // [39:39] is the sub-list for method output_type + 39, // [39:39] is the sub-list for method input_type + 39, // [39:39] is the sub-list for extension type_name + 0, // [0:39] is the sub-list for extension extendee 0, // [0:0] is the sub-list for field type_name } @@ -629,7 +665,7 @@ func file_api_proto_init() { RawDescriptor: file_api_proto_rawDesc, NumEnums: 0, NumMessages: 0, - NumExtensions: 37, + NumExtensions: 39, NumServices: 0, }, GoTypes: file_api_proto_goTypes, diff --git a/cmd/hz/protobuf/api/api.proto b/cmd/hz/protobuf/api/api.proto index 72edf41ff..71a3cd307 100644 --- a/cmd/hz/protobuf/api/api.proto +++ b/cmd/hz/protobuf/api/api.proto @@ -64,6 +64,7 @@ extend google.protobuf.ServiceOptions { // 50731~50760 used to extend service option by hz optional string base_domain_compatible = 50731; + optional string service_path = 50732; } extend google.protobuf.MessageOptions { diff --git a/cmd/hz/protobuf/ast.go b/cmd/hz/protobuf/ast.go index fcb8197eb..264dc4e30 100644 --- a/cmd/hz/protobuf/ast.go +++ b/cmd/hz/protobuf/ast.go @@ -19,6 +19,7 @@ package protobuf import ( "fmt" "path/filepath" + "sort" "strings" "github.com/cloudwego/hertz/cmd/hz/generator" @@ -129,12 +130,25 @@ func astToService(ast *descriptorpb.FileDescriptorProto, resolver *Resolver, cmd ms := s.GetMethod() methods := make([]*generator.HttpMethod, 0, len(ms)) clientMethods := make([]*generator.ClientMethod, 0, len(ms)) + servicePathAnno := checkFirstOption(api.E_ServicePath, s.GetOptions()) + servicePath := "" + if val, ok := servicePathAnno.(string); ok { + servicePath = val + } for _, m := range ms { - hmethod, vpath := checkFirstOptions(HttpMethodOptions, m.GetOptions()) - if hmethod == "" { + rs := getAllOptions(HttpMethodOptions, m.GetOptions()) + if len(rs) == 0 { continue } - path := vpath.(string) + httpOpts := httpOptions{} + for k, v := range rs { + httpOpts = append(httpOpts, httpOption{ + method: k, + path: v.(string), + }) + } + // turn the map into a slice and sort it to make sure getting the results in the same order every time + sort.Sort(httpOpts) var handlerOutDir string genPath := getCompatibleAnnotation(m.GetOptions(), api.E_HandlerPath, api.E_HandlerPathCompatible) @@ -142,6 +156,9 @@ func astToService(ast *descriptorpb.FileDescriptorProto, resolver *Resolver, cmd if !ok || len(handlerOutDir) == 0 { handlerOutDir = "" } + if len(handlerOutDir) == 0 { + handlerOutDir = servicePath + } // protoGoInfo can get generated "Go Info" for proto file. // the type name may be different between "***.proto" and "***.pb.go" @@ -181,10 +198,11 @@ func astToService(ast *descriptorpb.FileDescriptorProto, resolver *Resolver, cmd method := &generator.HttpMethod{ Name: util.CamelString(m.GetName()), - HTTPMethod: hmethod, - Path: path, + HTTPMethod: httpOpts[0].method, + Path: httpOpts[0].path, Serializer: serializer, OutputDir: handlerOutDir, + GenHandler: true, } goOptMapAlias := make(map[string]string, 1) @@ -223,6 +241,16 @@ func astToService(ast *descriptorpb.FileDescriptorProto, resolver *Resolver, cmd method.ReturnTypePackage = respPackage methods = append(methods, method) + for idx, anno := range httpOpts { + if idx == 0 { + continue + } + tmp := *method + tmp.HTTPMethod = anno.method + tmp.Path = anno.path + tmp.GenHandler = false + methods = append(methods, &tmp) + } if cmdType == meta.CmdClient { clientMethod := &generator.ClientMethod{} diff --git a/cmd/hz/protobuf/tags.go b/cmd/hz/protobuf/tags.go index 23c046490..1abf23111 100644 --- a/cmd/hz/protobuf/tags.go +++ b/cmd/hz/protobuf/tags.go @@ -94,6 +94,38 @@ var ( SerializerOptions = map[*protoimpl.ExtensionInfo]string{api.E_Serializer: "serializer"} ) +type httpOption struct { + method string + path string +} + +type httpOptions []httpOption + +func (s httpOptions) Len() int { + return len(s) +} + +func (s httpOptions) Swap(i, j int) { + s[i], s[j] = s[j], s[i] +} + +func (s httpOptions) Less(i, j int) bool { + return s[i].method < s[j].method +} + +func getAllOptions(extensions map[*protoimpl.ExtensionInfo]string, opts ...protoreflect.ProtoMessage) map[string]interface{} { + out := map[string]interface{}{} + for _, opt := range opts { + for e, t := range extensions { + if proto.HasExtension(opt, e) { + v := proto.GetExtension(opt, e) + out[t] = v + } + } + } + return out +} + func checkFirstOptions(extensions map[*protoimpl.ExtensionInfo]string, opts ...protoreflect.ProtoMessage) (string, interface{}) { for _, opt := range opts { for e, t := range extensions { diff --git a/cmd/hz/thrift/ast.go b/cmd/hz/thrift/ast.go index 557b50a41..f634d9605 100644 --- a/cmd/hz/thrift/ast.go +++ b/cmd/hz/thrift/ast.go @@ -18,6 +18,7 @@ package thrift import ( "fmt" + "sort" "strings" "github.com/cloudwego/hertz/cmd/hz/config" @@ -98,26 +99,34 @@ func astToService(ast *parser.Thrift, resolver *Resolver, args *config.Argument) } methods := make([]*generator.HttpMethod, 0, len(ms)) clientMethods := make([]*generator.ClientMethod, 0, len(ms)) + servicePathAnno := getAnnotation(s.Annotations, ApiServicePath) + servicePath := "" + if len(servicePathAnno) > 0 { + servicePath = servicePathAnno[0] + } for _, m := range ms { rs := getAnnotations(m.Annotations, HttpMethodAnnotations) - if len(rs) > 1 { - return nil, fmt.Errorf("too many 'api.XXX' annotations: %s", rs) - } if len(rs) == 0 { continue } - - var handlerOutDir string + httpAnnos := httpAnnotations{} + for k, v := range rs { + httpAnnos = append(httpAnnos, httpAnnotation{ + method: k, + path: v, + }) + } + // turn the map into a slice and sort it to make sure getting the results in the same order every time + sort.Sort(httpAnnos) + handlerOutDir := servicePath genPaths := getAnnotation(m.Annotations, ApiGenPath) - if len(genPaths) == 0 { - handlerOutDir = "" - } else if len(genPaths) > 1 { - return nil, fmt.Errorf("too many 'api.handler_path' for %s", m.Name) - } else { + if len(genPaths) == 1 { handlerOutDir = genPaths[0] + } else if len(genPaths) > 0 { + return nil, fmt.Errorf("too many 'api.handler_path' for %s", m.Name) } - hmethod, path := util.GetFirstKV(rs) + hmethod, path := httpAnnos[0].method, httpAnnos[0].path if len(path) != 1 || path[0] == "" { return nil, fmt.Errorf("invalid api.%s for %s.%s: %s", hmethod, s.Name, m.Name, path) } @@ -175,6 +184,7 @@ func astToService(ast *parser.Thrift, resolver *Resolver, args *config.Argument) Path: path[0], Serializer: sr, OutputDir: handlerOutDir, + GenHandler: true, // Annotations: m.Annotations, } refs := resolver.ExportReferred(false, true) @@ -187,6 +197,20 @@ func astToService(ast *parser.Thrift, resolver *Resolver, args *config.Argument) } models.MergeMap(method.Models) methods = append(methods, method) + for idx, anno := range httpAnnos { + if idx == 0 { + continue + } + tmp := *method + hmethod, path := anno.method, anno.path + if len(path) != 1 || path[0] == "" { + return nil, fmt.Errorf("invalid api.%s for %s.%s: %s", hmethod, s.Name, m.Name, path) + } + tmp.HTTPMethod = hmethod + tmp.Path = path[0] + tmp.GenHandler = false + methods = append(methods, &tmp) + } if args.CmdType == meta.CmdClient { clientMethod := &generator.ClientMethod{} clientMethod.HttpMethod = method @@ -353,18 +377,7 @@ func getAllExtendFunction(svc *parser.Service, ast *parser.Thrift, resolver *Res } funcs := extendSvc.GetFunctions() for _, f := range funcs { - // the method of other file is extended, and the package of req/resp needs to be changed - // ex. base.thrift -> Resp Method(Req){} - // base.Resp Method(base.Req){} - // todo: support container for Struct - if len(f.Arguments) > 0 { - if !strings.Contains(f.Arguments[0].Type.Name, ".") && f.Arguments[0].Type.Category.IsStruct() { - f.Arguments[0].Type.Name = base + "." + f.Arguments[0].Type.Name - } - } - if !strings.Contains(f.FunctionType.Name, ".") && f.FunctionType.Category.IsStruct() { - f.FunctionType.Name = base + "." + f.FunctionType.Name - } + processExtendsType(f, base) } extendFuncs, err := getAllExtendFunction(extendSvc, ast, resolver, args) if err != nil { @@ -392,18 +405,7 @@ func getAllExtendFunction(svc *parser.Service, ast *parser.Thrift, resolver *Res if found { funcs := extendSvc.GetFunctions() for _, f := range funcs { - // the method of other file is extended, and the package of req/resp needs to be changed - // ex. base.thrift -> Resp Method(Req){} - // base.Resp Method(base.Req){} - // todo: support container for Struct - if len(f.Arguments) > 0 { - if !strings.Contains(f.Arguments[0].Type.Name, ".") && f.Arguments[0].Type.Category.IsStruct() { - f.Arguments[0].Type.Name = base + "." + f.Arguments[0].Type.Name - } - } - if !strings.Contains(f.FunctionType.Name, ".") && f.FunctionType.Category.IsStruct() { - f.FunctionType.Name = base + "." + f.FunctionType.Name - } + processExtendsType(f, base) } extendFuncs, err := getAllExtendFunction(extendSvc, refAst, resolver, args) if err != nil { @@ -418,6 +420,53 @@ func getAllExtendFunction(svc *parser.Service, ast *parser.Thrift, resolver *Res return res, nil } +func processExtendsType(f *parser.Function, base string) { + // the method of other file is extended, and the package of req/resp needs to be changed + // ex. base.thrift -> Resp Method(Req){} + // base.Resp Method(base.Req){} + if len(f.Arguments) > 0 { + if f.Arguments[0].Type.Category.IsContainerType() { + switch f.Arguments[0].Type.Category { + case parser.Category_Set, parser.Category_List: + if !strings.Contains(f.Arguments[0].Type.ValueType.Name, ".") && f.Arguments[0].Type.ValueType.Category.IsStruct() { + f.Arguments[0].Type.ValueType.Name = base + "." + f.Arguments[0].Type.ValueType.Name + } + case parser.Category_Map: + if !strings.Contains(f.Arguments[0].Type.ValueType.Name, ".") && f.Arguments[0].Type.ValueType.Category.IsStruct() { + f.Arguments[0].Type.ValueType.Name = base + "." + f.Arguments[0].Type.ValueType.Name + } + if !strings.Contains(f.Arguments[0].Type.KeyType.Name, ".") && f.Arguments[0].Type.KeyType.Category.IsStruct() { + f.Arguments[0].Type.KeyType.Name = base + "." + f.Arguments[0].Type.KeyType.Name + } + } + } else { + if !strings.Contains(f.Arguments[0].Type.Name, ".") && f.Arguments[0].Type.Category.IsStruct() { + f.Arguments[0].Type.Name = base + "." + f.Arguments[0].Type.Name + } + } + } + + if f.FunctionType.Category.IsContainerType() { + switch f.FunctionType.Category { + case parser.Category_Set, parser.Category_List: + if !strings.Contains(f.FunctionType.ValueType.Name, ".") && f.FunctionType.ValueType.Category.IsStruct() { + f.FunctionType.ValueType.Name = base + "." + f.FunctionType.ValueType.Name + } + case parser.Category_Map: + if !strings.Contains(f.FunctionType.ValueType.Name, ".") && f.FunctionType.ValueType.Category.IsStruct() { + f.FunctionType.ValueType.Name = base + "." + f.FunctionType.ValueType.Name + } + if !strings.Contains(f.FunctionType.KeyType.Name, ".") && f.FunctionType.KeyType.Category.IsStruct() { + f.FunctionType.KeyType.Name = base + "." + f.FunctionType.KeyType.Name + } + } + } else { + if !strings.Contains(f.FunctionType.Name, ".") && f.FunctionType.Category.IsStruct() { + f.FunctionType.Name = base + "." + f.FunctionType.Name + } + } +} + func getUniqueResolveDependentName(name string, resolver *Resolver) string { rawName := name for i := 0; i < 10000; i++ { diff --git a/cmd/hz/thrift/plugin.go b/cmd/hz/thrift/plugin.go index ff4fdf65f..8b840e739 100644 --- a/cmd/hz/thrift/plugin.go +++ b/cmd/hz/thrift/plugin.go @@ -77,6 +77,8 @@ func (plugin *Plugin) Run() int { } plugin.rmTags = args.RmTags if args.CmdType == meta.CmdModel { + // check tag options for model mode + CheckTagOption(plugin.args) res, err := plugin.GetResponse(nil, args.OutDir) if err != nil { logs.Errorf("get response failed: %s", err.Error()) diff --git a/cmd/hz/thrift/tags.go b/cmd/hz/thrift/tags.go index 5b2e8b603..96c414dca 100644 --- a/cmd/hz/thrift/tags.go +++ b/cmd/hz/thrift/tags.go @@ -63,7 +63,8 @@ const ( const ( ApiBaseDomain = "api.base_domain" ApiServiceGroup = "api.service_group" - ApiServiceGenDir = "api.service_gen_dir" + ApiServiceGenDir = "api.service_gen_dir" // handler_dir for handler_by_service + ApiServicePath = "api.service_path" // declare the path to the service's handler according to this annotation for handler_by_method ) var ( @@ -141,6 +142,25 @@ func getAnnotation(input parser.Annotations, target string) []string { return []string{} } +type httpAnnotation struct { + method string + path []string +} + +type httpAnnotations []httpAnnotation + +func (s httpAnnotations) Len() int { + return len(s) +} + +func (s httpAnnotations) Swap(i, j int) { + s[i], s[j] = s[j], s[i] +} + +func (s httpAnnotations) Less(i, j int) bool { + return s[i].method < s[j].method +} + func getAnnotations(input parser.Annotations, targets map[string]string) map[string][]string { if len(input) == 0 || len(targets) == 0 { return nil diff --git a/pkg/app/client/client_test.go b/pkg/app/client/client_test.go index 8449cd4d8..e7dcbf506 100644 --- a/pkg/app/client/client_test.go +++ b/pkg/app/client/client_test.go @@ -2269,7 +2269,7 @@ func TestClientDoWithDialFunc(t *testing.T) { func TestClientState(t *testing.T) { opt := config.NewOptions([]config.Option{}) - opt.Addr = "127.0.0.1:11000" + opt.Addr = ":10037" engine := route.NewEngine(opt) go engine.Run() @@ -2282,12 +2282,12 @@ func TestClientState(t *testing.T) { case int32(0): assert.DeepEqual(t, 1, hcs.ConnPoolState().TotalConnNum) assert.DeepEqual(t, 1, hcs.ConnPoolState().PoolConnNum) - assert.DeepEqual(t, "127.0.0.1:11000", hcs.ConnPoolState().Addr) + assert.DeepEqual(t, "127.0.0.1:10037", hcs.ConnPoolState().Addr) atomic.StoreInt32(&state, int32(1)) case int32(1): assert.DeepEqual(t, 0, hcs.ConnPoolState().TotalConnNum) assert.DeepEqual(t, 0, hcs.ConnPoolState().PoolConnNum) - assert.DeepEqual(t, "127.0.0.1:11000", hcs.ConnPoolState().Addr) + assert.DeepEqual(t, "127.0.0.1:10037", hcs.ConnPoolState().Addr) atomic.StoreInt32(&state, int32(2)) return case int32(2): @@ -2295,7 +2295,7 @@ func TestClientState(t *testing.T) { } }, time.Second*9)) - client.Get(context.Background(), nil, "http://127.0.0.1:11000") + client.Get(context.Background(), nil, "http://127.0.0.1:10037") time.Sleep(time.Second * 22) } diff --git a/pkg/app/context_test.go b/pkg/app/context_test.go index 8bf52f437..7211e3910 100644 --- a/pkg/app/context_test.go +++ b/pkg/app/context_test.go @@ -17,10 +17,14 @@ package app import ( + "bytes" "context" + "encoding/xml" "errors" "fmt" + "html/template" "io/ioutil" + "os" "reflect" "strings" "testing" @@ -82,6 +86,17 @@ func TestContext(t *testing.T) { } } +func TestValue(t *testing.T) { + ctx := NewContext(0) + + v := ctx.Value("testContextKey") + assert.Nil(t, v) + + ctx.Set("testContextKey", "testValue") + v = ctx.Value("testContextKey") + assert.DeepEqual(t, "testValue", v) +} + func TestContextNotModified(t *testing.T) { reqContext := NewContext(0) reqContext.Response.SetStatusCode(consts.StatusOK) @@ -273,6 +288,10 @@ func TestQuery(t *testing.T) { t.Fatalf("unexpected query: %#v, expected menu", ctx.Query("name")) } + if ctx.DefaultQuery("name", "default value") != "menu" { + t.Fatalf("unexpected query: %#v, expected menu", ctx.Query("name")) + } + if ctx.DefaultQuery("defaultQuery", "default value") != "default value" { t.Fatalf("unexpected query: %#v, expected `default value`", ctx.Query("defaultQuery")) } @@ -436,6 +455,15 @@ tailfoobar` } } + err = ctx.SaveUploadedFile(f.File["fileaaa"][0], "TODO") + assert.Nil(t, err) + fileInfo, err := os.Stat("TODO") + assert.Nil(t, err) + assert.DeepEqual(t, "TODO", fileInfo.Name()) + assert.DeepEqual(t, f.File["fileaaa"][0].Size, fileInfo.Size()) + err = os.Remove("TODO") + assert.Nil(t, err) + ff, err := ctx.FormFile("fileaaa") if err != nil || ff == nil { t.Fatalf("unexpected error happened when ctx.FormFile()") @@ -525,6 +553,13 @@ func TestRequestContext_Header(t *testing.T) { if val != "" { t.Fatalf("unexpected %q. Expecting %q", val, "") } + + c.Header("header_key1", "header_val1") + c.Header("header_key1", "") + val = string(c.Response.Header.Peek("header_key1")) + if val != "" { + t.Fatalf("unexpected %q. Expecting %q", val, "") + } } func TestRequestContext_Keys(t *testing.T) { @@ -554,6 +589,10 @@ func TestRequestContext_Handler(t *testing.T) { if val != "123" { t.Fatalf("unexpected %v. Expecting %v", val, "123") } + + c.handlers = nil + handler := c.Handler() + assert.Nil(t, handler) } func TestRequestContext_Handlers(t *testing.T) { @@ -576,6 +615,24 @@ func TestRequestContext_HandlerName(t *testing.T) { } } +func TestNext(t *testing.T) { + c := NewContext(0) + a := 0 + + testFunc1 := func(c context.Context, ctx *RequestContext) { + a = 1 + } + testFunc3 := func(c context.Context, ctx *RequestContext) { + a = 3 + } + c.handlers = HandlersChain{testFunc1, testFunc3} + + c.Next(context.Background()) + + assert.True(t, c.index == 2) + assert.DeepEqual(t, 3, a) +} + func TestContextError(t *testing.T) { c := NewContext(0) assert.Nil(t, c.Errors) @@ -631,6 +688,77 @@ func TestRender(t *testing.T) { assert.DeepEqual(t, consts.StatusOK, c.Response.StatusCode()) assert.True(t, strings.Contains(string(c.Response.Body()), "test")) + + c.Reset() + c.Render(110, &render.Data{ + ContentType: "application/json; charset=utf-8", + Data: []byte("{\"test\":1}"), + }) + assert.DeepEqual(t, "application/json; charset=utf-8", string(c.Response.Header.ContentType())) + assert.DeepEqual(t, "", string(c.Response.Body())) + + c.Reset() + c.Render(consts.StatusNoContent, &render.Data{ + ContentType: "application/json; charset=utf-8", + Data: []byte("{\"test\":1}"), + }) + assert.DeepEqual(t, "application/json; charset=utf-8", string(c.Response.Header.ContentType())) + assert.DeepEqual(t, "", string(c.Response.Body())) + + c.Reset() + c.Render(consts.StatusNotModified, &render.Data{ + ContentType: "application/json; charset=utf-8", + Data: []byte("{\"test\":1}"), + }) + assert.DeepEqual(t, "application/json; charset=utf-8", string(c.Response.Header.ContentType())) + assert.DeepEqual(t, "", string(c.Response.Body())) +} + +func TestHTML(t *testing.T) { + c := NewContext(0) + + tmpl := template.Must(template.New(""). + Delims("{[{", "}]}"). + Funcs(template.FuncMap{}). + ParseFiles("../common/testdata/template/index.tmpl")) + + r := &render.HTMLProduction{Template: tmpl} + c.HTMLRender = r + c.HTML(consts.StatusOK, "index.tmpl", utils.H{"title": "Main website"}) + + assert.DeepEqual(t, []byte("text/html; charset=utf-8"), c.Response.Header.Peek("Content-Type")) + assert.DeepEqual(t, []byte("

Main website

"), c.Response.Body()) +} + +type xmlmap map[string]interface{} + +// Allows type H to be used with xml.Marshal +func (h xmlmap) MarshalXML(e *xml.Encoder, start xml.StartElement) error { + start.Name = xml.Name{ + Space: "", + Local: "map", + } + if err := e.EncodeToken(start); err != nil { + return err + } + for key, value := range h { + elem := xml.StartElement{ + Name: xml.Name{Space: "", Local: key}, + Attr: []xml.Attr{}, + } + if err := e.EncodeElement(value, elem); err != nil { + return err + } + } + + return e.EncodeToken(xml.EndElement{Name: start.Name}) +} + +func TestXML(t *testing.T) { + c := NewContext(0) + c.XML(consts.StatusOK, xmlmap{"foo": "bar"}) + assert.DeepEqual(t, []byte("bar"), c.Response.Body()) + assert.DeepEqual(t, []byte("application/xml; charset=utf-8"), c.Response.Header.Peek("Content-Type")) } func TestJSON(t *testing.T) { @@ -654,6 +782,7 @@ func TestContextReset(t *testing.T) { c.Params = param.Params{param.Param{}} c.Error(errors.New("test")) // nolint: errcheck c.Set("foo", "bar") + c.Finished() c.Request.SetIsTLS(true) c.ResetWithoutConn() c.Request.URI() @@ -664,6 +793,7 @@ func TestContextReset(t *testing.T) { assert.Nil(t, c.Errors.ByType(errs.ErrorTypeAny)) assert.DeepEqual(t, 0, len(c.Params)) assert.DeepEqual(t, int8(-1), c.index) + assert.Nil(t, c.finished) } func TestContextContentType(t *testing.T) { @@ -751,6 +881,16 @@ func TestRemoteAddr(t *testing.T) { assert.DeepEqual(t, "0.0.0.0:0", addr) } +func TestRequestBodyStream(t *testing.T) { + c := NewContext(0) + s := "testRequestBodyStream" + mr := bytes.NewBufferString(s) + c.Request.SetBodyStream(mr, -1) + data, err := ioutil.ReadAll(c.RequestBodyStream()) + assert.Nil(t, err) + assert.DeepEqual(t, "testRequestBodyStream", string(data)) +} + func TestContextIsAborted(t *testing.T) { ctx := NewContext(0) assert.False(t, ctx.IsAborted()) @@ -828,6 +968,23 @@ func TestRequestCtxFormValue(t *testing.T) { if string(v) != "1" { t.Fatalf("unexpected value %q. Expecting %q", v, "1") } + + ctx.Request.Reset() + s := `------WebKitFormBoundaryJwfATyF8tmxSJnLg +Content-Disposition: form-data; name="f" + +fff +------WebKitFormBoundaryJwfATyF8tmxSJnLg +` + mr := bytes.NewBufferString(s) + ctx.Request.SetBodyStream(mr, -1) + ctx.Request.Header.SetContentLength(len(s)) + ctx.Request.Header.SetContentTypeBytes([]byte("multipart/form-data; boundary=----WebKitFormBoundaryJwfATyF8tmxSJnLg")) + + v = ctx.FormValue("f") + if string(v) != "fff" { + t.Fatalf("unexpected value %q. Expecting %q", v, "fff") + } } func TestSetCustomFormValueFunc(t *testing.T) { @@ -1076,6 +1233,25 @@ func TestForEachKey(t *testing.T) { assert.True(t, ok) } +func TestFlush(t *testing.T) { + ctx := NewContext(0) + err := ctx.Flush() + assert.Nil(t, err) +} + +func TestConn(t *testing.T) { + ctx := NewContext(0) + + conn := mock.NewConn("") + + ctx.SetConn(conn) + connRes := ctx.GetConn() + + val1 := reflect.ValueOf(conn).Pointer() + val2 := reflect.ValueOf(connRes).Pointer() + assert.DeepEqual(t, val1, val2) +} + func TestHijackHandler(t *testing.T) { ctx := NewContext(0) handle := func(c network.Conn) { @@ -1089,6 +1265,32 @@ func TestHijackHandler(t *testing.T) { assert.DeepEqual(t, val1, val2) } +func TestGetReader(t *testing.T) { + ctx := NewContext(0) + + conn := mock.NewConn("") + + ctx.SetConn(conn) + connRes := ctx.GetReader() + + val1 := reflect.ValueOf(conn).Pointer() + val2 := reflect.ValueOf(connRes).Pointer() + assert.DeepEqual(t, val1, val2) +} + +func TestGetWriter(t *testing.T) { + ctx := NewContext(0) + + conn := mock.NewConn("") + + ctx.SetConn(conn) + connRes := ctx.GetWriter() + + val1 := reflect.ValueOf(conn).Pointer() + val2 := reflect.ValueOf(connRes).Pointer() + assert.DeepEqual(t, val1, val2) +} + func TestIndex(t *testing.T) { ctx := NewContext(0) ctx.ResetWithoutConn() diff --git a/pkg/app/fs_test.go b/pkg/app/fs_test.go index 1cd545e8c..c4ad0e426 100644 --- a/pkg/app/fs_test.go +++ b/pkg/app/fs_test.go @@ -53,6 +53,7 @@ import ( "testing" "time" + "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/common/test/mock" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/consts" @@ -223,6 +224,32 @@ func TestServeFileSmallNoReadFrom(t *testing.T) { if body != teststr { t.Fatalf("expected '%s'", teststr) } + + data := make([]byte, len([]byte(teststr))) + nn, err := reader.Read(data) + assert.DeepEqual(t, len([]byte(teststr)), nn) + assert.Nil(t, err) + assert.DeepEqual(t, teststr, string(data)) + assert.DeepEqual(t, reader.startPos, len([]byte(teststr))) + + nn, err = reader.Read(data) + assert.DeepEqual(t, 0, nn) + assert.DeepEqual(t, io.EOF, err) + + data1 := make([]byte, 2) + reader.startPos = len([]byte(teststr)) - 1 + nn, err = reader.Read(data1) + assert.DeepEqual(t, []byte("!"), []byte{data1[0]}) + assert.DeepEqual(t, 1, nn) + assert.DeepEqual(t, nil, err) + + reader.startPos = 0 + reader.ff.f = nil + buf = bytes.NewBuffer(nil) + reader.ff.dirIndex = make([]byte, len([]byte(teststr))) + n, err = reader.WriteTo(pureWriter{buf}) + assert.DeepEqual(t, int64(len(teststr)), n) + assert.Nil(t, err) } type pureWriter struct { @@ -660,3 +687,11 @@ func TestServeFileContentType(t *testing.T) { t.Fatalf("Unexpected Content-Type, expected: %q got %q", expected, r.Header.ContentType()) } } + +func TestFileSmallUpdateByteRange(t *testing.T) { + r := &fsSmallFileReader{} + err := r.UpdateByteRange(1, 1) + assert.Nil(t, err) + assert.DeepEqual(t, 1, r.startPos) + assert.DeepEqual(t, 2, r.endPos) +} diff --git a/pkg/app/middlewares/client/sd/options_test.go b/pkg/app/middlewares/client/sd/options_test.go index 44a7356f9..d0fad1232 100644 --- a/pkg/app/middlewares/client/sd/options_test.go +++ b/pkg/app/middlewares/client/sd/options_test.go @@ -20,16 +20,27 @@ import ( "context" "testing" + "github.com/cloudwego/hertz/pkg/app/client/loadbalance" "github.com/cloudwego/hertz/pkg/common/test/assert" ) func TestWithCustomizedAddrs(t *testing.T) { var options []ServiceDiscoveryOption - options = append(options, WithCustomizedAddrs("127.0.0.1:8080")) + options = append(options, WithCustomizedAddrs("127.0.0.1:8080", "/tmp/unix_ss")) opts := &ServiceDiscoveryOptions{} opts.Apply(options) - assert.Assert(t, opts.Resolver.Name() == "127.0.0.1:8080") + assert.Assert(t, opts.Resolver.Name() == "127.0.0.1:8080,/tmp/unix_ss") res, err := opts.Resolver.Resolve(context.Background(), "") assert.Assert(t, err == nil) assert.Assert(t, res.Instances[0].Address().String() == "127.0.0.1:8080") + assert.Assert(t, res.Instances[1].Address().String() == "/tmp/unix_ss") +} + +func TestWithLoadBalanceOptions(t *testing.T) { + balance := loadbalance.NewWeightedBalancer() + var options []ServiceDiscoveryOption + options = append(options, WithLoadBalanceOptions(balance, loadbalance.DefaultLbOpts)) + opts := &ServiceDiscoveryOptions{} + opts.Apply(options) + assert.Assert(t, opts.Balancer.Name() == "weight_random") } diff --git a/pkg/app/server/binding/binding_test.go b/pkg/app/server/binding/binding_test.go index 84c5980c5..a9050f940 100644 --- a/pkg/app/server/binding/binding_test.go +++ b/pkg/app/server/binding/binding_test.go @@ -164,6 +164,19 @@ func TestJsonBind(t *testing.T) { // NOTE: The default does not support string to go int conversion in json. // You can add "string" tags or use other json unmarshal libraries that support this feature assert.DeepEqual(t, 100, req.D) + + req = Test{} + UseStdJSONUnmarshaler() + err = BindAndValidate(r, &req, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + assert.DeepEqual(t, "aaa", req.A) + assert.DeepEqual(t, 2, len(req.B)) + assert.DeepEqual(t, "ccc", req.C) + // NOTE: The default does not support string to go int conversion in json. + // You can add "string" tags or use other json unmarshal libraries that support this feature + assert.DeepEqual(t, 100, req.D) } // TestQueryParamInconsistency tests the Inconsistency for GetQuery(), the other unit test for GetFunc() in request.go are similar to it diff --git a/pkg/app/server/binding/request_test.go b/pkg/app/server/binding/request_test.go new file mode 100644 index 000000000..b3bb70523 --- /dev/null +++ b/pkg/app/server/binding/request_test.go @@ -0,0 +1,235 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package binding + +import ( + "bytes" + "fmt" + "testing" + + "github.com/cloudwego/hertz/pkg/common/test/assert" + "github.com/cloudwego/hertz/pkg/protocol" + "github.com/cloudwego/hertz/pkg/protocol/consts" +) + +func TestGetQuery(t *testing.T) { + r := protocol.NewRequest("GET", "/foo", nil) + r.SetRequestURI("/foo/bar?para1=hertz¶2=query1¶2=query2¶3=1¶3=2") + + bindReq := bindRequest{ + req: r, + } + + values := bindReq.GetQuery() + + assert.DeepEqual(t, []string{"hertz"}, values["para1"]) + assert.DeepEqual(t, []string{"query1", "query2"}, values["para2"]) + assert.DeepEqual(t, []string{"1", "2"}, values["para3"]) +} + +func TestGetPostForm(t *testing.T) { + data := "a=aaa&b=b1&b=b2&c=ccc&d=100" + mr := bytes.NewBufferString(data) + + r := protocol.NewRequest("POST", "/foo", mr) + r.Header.SetContentTypeBytes([]byte(consts.MIMEApplicationHTMLForm)) + r.Header.SetContentLength(len(data)) + + bindReq := bindRequest{ + req: r, + } + + values, err := bindReq.GetPostForm() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + assert.DeepEqual(t, []string{"aaa"}, values["a"]) + assert.DeepEqual(t, []string{"b1", "b2"}, values["b"]) + assert.DeepEqual(t, []string{"ccc"}, values["c"]) + assert.DeepEqual(t, []string{"100"}, values["d"]) +} + +func TestGetForm(t *testing.T) { + data := "a=aaa&b=b1&b=b2&c=ccc&d=100" + mr := bytes.NewBufferString(data) + + r := protocol.NewRequest("POST", "/foo", mr) + r.SetRequestURI("/foo/bar?para1=hertz¶2=query1¶2=query2¶3=1¶3=2") + r.Header.SetContentTypeBytes([]byte(consts.MIMEApplicationHTMLForm)) + r.Header.SetContentLength(len(data)) + + bindReq := bindRequest{ + req: r, + } + + values, err := bindReq.GetForm() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + assert.DeepEqual(t, []string{"aaa"}, values["a"]) + assert.DeepEqual(t, []string{"b1", "b2"}, values["b"]) + assert.DeepEqual(t, []string{"ccc"}, values["c"]) + assert.DeepEqual(t, []string{"100"}, values["d"]) + assert.DeepEqual(t, []string{"hertz"}, values["para1"]) + assert.DeepEqual(t, []string{"query1", "query2"}, values["para2"]) + assert.DeepEqual(t, []string{"1", "2"}, values["para3"]) +} + +func TestGetCookies(t *testing.T) { + r := protocol.NewRequest("POST", "/foo", nil) + r.SetCookie("cookie1", "cookies1") + r.SetCookie("cookie2", "cookies2") + + bindReq := bindRequest{ + req: r, + } + + values := bindReq.GetCookies() + + assert.DeepEqual(t, "cookies1", values[0].Value) + assert.DeepEqual(t, "cookies2", values[1].Value) +} + +func TestGetHeader(t *testing.T) { + headers := map[string]string{ + "Header1": "headers1", + "Header2": "headers2", + } + + r := protocol.NewRequest("GET", "/foo", nil) + r.SetHeaders(headers) + r.SetHeader("Header3", "headers3") + + bindReq := bindRequest{ + req: r, + } + + values := bindReq.GetHeader() + + assert.DeepEqual(t, []string{"headers1"}, values["Header1"]) + assert.DeepEqual(t, []string{"headers2"}, values["Header2"]) + assert.DeepEqual(t, []string{"headers3"}, values["Header3"]) +} + +func TestGetMethod(t *testing.T) { + r := protocol.NewRequest("GET", "/foo", nil) + + bindReq := bindRequest{ + req: r, + } + + values := bindReq.GetMethod() + + assert.DeepEqual(t, "GET", values) +} + +func TestGetContentType(t *testing.T) { + data := "a=aaa&b=b1&b=b2&c=ccc&d=100" + mr := bytes.NewBufferString(data) + + r := protocol.NewRequest("POST", "/foo", mr) + r.Header.SetContentTypeBytes([]byte(consts.MIMEApplicationHTMLForm)) + r.Header.SetContentLength(len(data)) + + bindReq := bindRequest{ + req: r, + } + + values := bindReq.GetContentType() + + assert.DeepEqual(t, consts.MIMEApplicationHTMLForm, values) +} + +func TestGetBody(t *testing.T) { + data := "a=aaa&b=b1&b=b2&c=ccc&d=100" + mr := bytes.NewBufferString(data) + + r := protocol.NewRequest("POST", "/foo", mr) + r.Header.SetContentTypeBytes([]byte(consts.MIMEApplicationHTMLForm)) + r.Header.SetContentLength(len(data)) + + bindReq := bindRequest{ + req: r, + } + + values, err := bindReq.GetBody() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + assert.DeepEqual(t, []byte("a=aaa&b=b1&b=b2&c=ccc&d=100"), values) +} + +func TestGetFileHeaders(t *testing.T) { + s := `------WebKitFormBoundaryJwfATyF8tmxSJnLg +Content-Disposition: form-data; name="f" + +fff +------WebKitFormBoundaryJwfATyF8tmxSJnLg +Content-Disposition: form-data; name="F1"; filename="TODO1" +Content-Type: application/octet-stream + +- SessionClient with referer and cookies support. +- Client with requests' pipelining support. +- ProxyHandler similar to FSHandler. +- WebSockets. See https://tools.ietf.org/html/rfc6455 . +- HTTP/2.0. See https://tools.ietf.org/html/rfc7540 . +------WebKitFormBoundaryJwfATyF8tmxSJnLg +Content-Disposition: form-data; name="F1"; filename="TODO2" +Content-Type: application/octet-stream + +- SessionClient with referer and cookies support. +- Client with requests' pipelining support. +- ProxyHandler similar to FSHandler. +- WebSockets. See https://tools.ietf.org/html/rfc6455 . +- HTTP/2.0. See https://tools.ietf.org/html/rfc7540 . +------WebKitFormBoundaryJwfATyF8tmxSJnLg +Content-Disposition: form-data; name="F2"; filename="TODO3" +Content-Type: application/octet-stream + +- SessionClient with referer and cookies support. +- Client with requests' pipelining support. +- ProxyHandler similar to FSHandler. +- WebSockets. See https://tools.ietf.org/html/rfc6455 . +- HTTP/2.0. See https://tools.ietf.org/html/rfc7540 . + +------WebKitFormBoundaryJwfATyF8tmxSJnLg-- +tailfoobar` + + mr := bytes.NewBufferString(s) + + r := protocol.NewRequest("POST", "/foo", mr) + r.Header.SetContentTypeBytes([]byte("multipart/form-data; boundary=----WebKitFormBoundaryJwfATyF8tmxSJnLg")) + r.Header.SetContentLength(len(s)) + + bindReq := bindRequest{ + req: r, + } + + values, err := bindReq.GetFileHeaders() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + fmt.Printf("%v\n", values) + + assert.DeepEqual(t, "TODO1", values["F1"][0].Filename) + assert.DeepEqual(t, "TODO2", values["F1"][1].Filename) + assert.DeepEqual(t, "TODO3", values["F2"][0].Filename) +} diff --git a/pkg/app/server/hertz_test.go b/pkg/app/server/hertz_test.go index bc3e1a0e4..3ce91a83f 100644 --- a/pkg/app/server/hertz_test.go +++ b/pkg/app/server/hertz_test.go @@ -50,7 +50,7 @@ import ( ) func TestHertz_Run(t *testing.T) { - hertz := New(WithHostPorts("127.0.0.1:6666")) + hertz := Default(WithHostPorts("127.0.0.1:6666")) hertz.GET("/test", func(c context.Context, ctx *app.RequestContext) { time.Sleep(time.Second) path := ctx.Request.URI().PathOriginal() @@ -62,6 +62,8 @@ func TestHertz_Run(t *testing.T) { atomic.StoreUint32(&testint, 1) }) + assert.Assert(t, len(hertz.Handlers) == 1) + go hertz.Spin() time.Sleep(100 * time.Millisecond) @@ -145,7 +147,7 @@ func TestHertz_GracefulShutdown(t *testing.T) { } func TestLoadHTMLGlob(t *testing.T) { - engine := New(WithMaxRequestBodySize(15), WithHostPorts("127.0.0.1:8890")) + engine := New(WithMaxRequestBodySize(15), WithHostPorts("127.0.0.1:8893")) engine.Delims("{[{", "}]}") engine.LoadHTMLGlob("../../common/testdata/template/index.tmpl") engine.GET("/index", func(c context.Context, ctx *app.RequestContext) { @@ -155,7 +157,7 @@ func TestLoadHTMLGlob(t *testing.T) { }) go engine.Run() time.Sleep(200 * time.Millisecond) - resp, _ := http.Get("http://127.0.0.1:8890/index") + resp, _ := http.Get("http://127.0.0.1:8893/index") assert.DeepEqual(t, consts.StatusOK, resp.StatusCode) b := make([]byte, 100) n, _ := resp.Body.Read(b) @@ -690,7 +692,7 @@ type CloseWithoutResetBuffer interface { } func TestOnprepare(t *testing.T) { - h := New( + h1 := New( WithHostPorts("localhost:9229"), WithOnConnect(func(ctx context.Context, conn network.Conn) context.Context { b, err := conn.Peek(3) @@ -703,42 +705,42 @@ func TestOnprepare(t *testing.T) { } return ctx })) - h.GET("/ping", func(ctx context.Context, c *app.RequestContext) { + h1.GET("/ping", func(ctx context.Context, c *app.RequestContext) { c.JSON(consts.StatusOK, utils.H{"ping": "pong"}) }) - go h.Spin() + go h1.Spin() time.Sleep(time.Second) _, _, err := c.Get(context.Background(), nil, "http://127.0.0.1:9229/ping") assert.DeepEqual(t, "the server closed connection before returning the first response byte. Make sure the server returns 'Connection: close' response header before closing the connection", err.Error()) - h = New( + h2 := New( WithOnAccept(func(conn net.Conn) context.Context { conn.Close() return context.Background() }), WithHostPorts("localhost:9230")) - h.GET("/ping", func(ctx context.Context, c *app.RequestContext) { + h2.GET("/ping", func(ctx context.Context, c *app.RequestContext) { c.JSON(consts.StatusOK, utils.H{"ping": "pong"}) }) - go h.Spin() + go h2.Spin() time.Sleep(time.Second) _, _, err = c.Get(context.Background(), nil, "http://127.0.0.1:9230/ping") if err == nil { t.Fatalf("err should not be nil") } - h = New( + h3 := New( WithOnAccept(func(conn net.Conn) context.Context { assert.DeepEqual(t, conn.LocalAddr().String(), "127.0.0.1:9231") return context.Background() }), WithHostPorts("localhost:9231"), WithTransport(standard.NewTransporter)) - h.GET("/ping", func(ctx context.Context, c *app.RequestContext) { + h3.GET("/ping", func(ctx context.Context, c *app.RequestContext) { c.JSON(consts.StatusOK, utils.H{"ping": "pong"}) }) - go h.Spin() + go h3.Spin() time.Sleep(time.Second) c.Get(context.Background(), nil, "http://127.0.0.1:9231/ping") } diff --git a/pkg/app/server/option_test.go b/pkg/app/server/option_test.go index d9abda2c2..aef554c14 100644 --- a/pkg/app/server/option_test.go +++ b/pkg/app/server/option_test.go @@ -17,6 +17,10 @@ package server import ( + "context" + "net" + "reflect" + "syscall" "testing" "time" @@ -25,6 +29,7 @@ import ( "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/common/tracer/stats" "github.com/cloudwego/hertz/pkg/common/utils" + "github.com/cloudwego/hertz/pkg/network" ) func TestOptions(t *testing.T) { @@ -33,6 +38,12 @@ func TestOptions(t *testing.T) { Weight: 10, Addr: utils.NewNetAddr("tcp", ":8888"), } + cfg := &net.ListenConfig{Control: func(network, address string, c syscall.RawConn) error { + return c.Control(func(fd uintptr) {}) + }} + transporter := func(options *config.Options) network.Transporter { + return &mockTransporter{} + } opt := config.NewOptions([]config.Option{ WithReadTimeout(time.Second), WithWriteTimeout(time.Second), @@ -62,6 +73,8 @@ func TestOptions(t *testing.T) { WithTraceLevel(stats.LevelDisabled), WithRegistry(nil, info), WithAutoReloadRender(true, 5*time.Second), + WithListenConfig(cfg), + WithAltTransport(transporter), }) assert.DeepEqual(t, opt.ReadTimeout, time.Second) assert.DeepEqual(t, opt.WriteTimeout, time.Second) @@ -92,6 +105,8 @@ func TestOptions(t *testing.T) { assert.DeepEqual(t, opt.Registry, nil) assert.DeepEqual(t, opt.AutoReloadRender, true) assert.DeepEqual(t, opt.AutoReloadInterval, 5*time.Second) + assert.DeepEqual(t, opt.ListenConfig, cfg) + assert.Assert(t, reflect.TypeOf(opt.AltTransporterNewer) == reflect.TypeOf(transporter)) } func TestDefaultOptions(t *testing.T) { @@ -125,3 +140,17 @@ func TestDefaultOptions(t *testing.T) { assert.DeepEqual(t, opt.AutoReloadRender, false) assert.DeepEqual(t, opt.AutoReloadInterval, time.Duration(0)) } + +type mockTransporter struct{} + +func (m *mockTransporter) ListenAndServe(onData network.OnData) (err error) { + panic("implement me") +} + +func (m *mockTransporter) Close() error { + panic("implement me") +} + +func (m *mockTransporter) Shutdown(ctx context.Context) error { + panic("implement me") +} diff --git a/pkg/app/server/render/html_test.go b/pkg/app/server/render/html_test.go index eb8f36e96..d474f1c63 100644 --- a/pkg/app/server/render/html_test.go +++ b/pkg/app/server/render/html_test.go @@ -17,14 +17,24 @@ package render import ( + "html/template" "io/ioutil" "os" "testing" "time" + + "github.com/cloudwego/hertz/pkg/common/test/assert" + "github.com/cloudwego/hertz/pkg/common/utils" + "github.com/cloudwego/hertz/pkg/protocol" ) func TestHTMLDebug_StartChecker_timer(t *testing.T) { - render := &HTMLDebug{RefreshInterval: time.Second} + render := &HTMLDebug{ + RefreshInterval: time.Second, + Delims: Delims{Left: "{[{", Right: "}]}"}, + FuncMap: template.FuncMap{}, + Files: []string{"../../../common/testdata/template/index.tmpl"}, + } select { case <-render.reloadCh: t.Fatalf("should not be triggered") @@ -35,6 +45,7 @@ func TestHTMLDebug_StartChecker_timer(t *testing.T) { case <-time.After(render.RefreshInterval + 500*time.Millisecond): t.Fatalf("should be triggered in 1.5 second") case <-render.reloadCh: + render.reload() } } @@ -64,3 +75,55 @@ func TestHTMLDebug_StartChecker_fs_watcher(t *testing.T) { default: } } + +func TestRenderHTML(t *testing.T) { + resp := &protocol.Response{} + + tmpl := template.Must(template.New(""). + Delims("{[{", "}]}"). + Funcs(template.FuncMap{}). + ParseFiles("../../../common/testdata/template/index.tmpl")) + + r := &HTMLProduction{Template: tmpl} + + html := r.Instance("index.tmpl", utils.H{ + "title": "Main website", + }) + + err := r.Close() + assert.Nil(t, err) + + html.WriteContentType(resp) + assert.DeepEqual(t, []byte("text/html; charset=utf-8"), resp.Header.Peek("Content-Type")) + + err = html.Render(resp) + + assert.Nil(t, err) + assert.DeepEqual(t, []byte("text/html; charset=utf-8"), resp.Header.Peek("Content-Type")) + assert.DeepEqual(t, []byte("

Main website

"), resp.Body()) + + respDebug := &protocol.Response{} + + rDebug := &HTMLDebug{ + Template: tmpl, + Delims: Delims{Left: "{[{", Right: "}]}"}, + FuncMap: template.FuncMap{}, + Files: []string{"../../../common/testdata/template/index.tmpl"}, + } + + htmlDebug := rDebug.Instance("index.tmpl", utils.H{ + "title": "Main website", + }) + + err = rDebug.Close() + assert.Nil(t, err) + + htmlDebug.WriteContentType(respDebug) + assert.DeepEqual(t, []byte("text/html; charset=utf-8"), respDebug.Header.Peek("Content-Type")) + + err = htmlDebug.Render(respDebug) + + assert.Nil(t, err) + assert.DeepEqual(t, []byte("text/html; charset=utf-8"), respDebug.Header.Peek("Content-Type")) + assert.DeepEqual(t, []byte("

Main website

"), respDebug.Body()) +} diff --git a/pkg/app/server/render/render_test.go b/pkg/app/server/render/render_test.go index 338d151df..0669cb48c 100644 --- a/pkg/app/server/render/render_test.go +++ b/pkg/app/server/render/render_test.go @@ -47,6 +47,7 @@ import ( "github.com/bytedance/sonic" "github.com/cloudwego/hertz/pkg/common/test/assert" + "github.com/cloudwego/hertz/pkg/common/testdata/proto" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/consts" ) @@ -96,8 +97,59 @@ func TestRenderJSONError(t *testing.T) { resp := &protocol.Response{} data := make(chan int) + err := (JSONRender{data}).Render(resp) + // json: unsupported type: chan int + assert.NotNil(t, err) +} + +func TestRenderPureJSON(t *testing.T) { + resp := &protocol.Response{} + data := map[string]interface{}{ + "foo": "bar", + "html": "", + } + + (PureJSON{data}).WriteContentType(resp) + assert.DeepEqual(t, []byte(consts.MIMEApplicationJSONUTF8), resp.Header.Peek("Content-Type")) + + err := (PureJSON{data}).Render(resp) + + assert.Nil(t, err) + + assert.DeepEqual(t, []byte("{\"foo\":\"bar\",\"html\":\"\"}\n"), resp.Body()) + assert.DeepEqual(t, []byte(consts.MIMEApplicationJSONUTF8), resp.Header.Peek("Content-Type")) +} + +func TestRenderPureJSONError(t *testing.T) { + resp := &protocol.Response{} + data := make(chan int) + + err := (PureJSON{data}).Render(resp) // json: unsupported type: chan int - assert.NotNil(t, func() { (JSONRender{data}).Render(resp) }) + assert.NotNil(t, err) +} + +func TestRenderProtobuf(t *testing.T) { + resp := &protocol.Response{} + data := proto.TestStruct{Body: []byte("Hello World")} + + (ProtoBuf{&data}).WriteContentType(resp) + assert.DeepEqual(t, []byte("application/x-protobuf"), resp.Header.Peek("Content-Type")) + + err := (ProtoBuf{&data}).Render(resp) + + assert.Nil(t, err) + assert.DeepEqual(t, []byte("\n\vHello World"), resp.Body()) + assert.DeepEqual(t, []byte("application/x-protobuf"), resp.Header.Peek("Content-Type")) +} + +func TestRenderProtobufError(t *testing.T) { + resp := &protocol.Response{} + data := proto.Test{} + + err := (ProtoBuf{&data}).Render(resp) + + assert.NotNil(t, err) } func TestRenderString(t *testing.T) { @@ -162,6 +214,15 @@ func TestRenderXML(t *testing.T) { assert.DeepEqual(t, []byte(consts.MIMEApplicationXMLUTF8), resp.Header.Peek("Content-Type")) } +func TestRenderXMLError(t *testing.T) { + resp := &protocol.Response{} + data := make(chan int) + + err := (XML{data}).Render(resp) + + assert.NotNil(t, err) +} + func TestRenderIndentedJSON(t *testing.T) { data := map[string]interface{}{ "foo": "bar", diff --git a/pkg/common/test/mock/reader.go b/pkg/common/test/mock/reader.go index 7a352653d..b67d50354 100644 --- a/pkg/common/test/mock/reader.go +++ b/pkg/common/test/mock/reader.go @@ -65,6 +65,13 @@ func NewZeroCopyReader(r string) ZeroCopyReader { return ZeroCopyReader{br} } +func NewLimitReader(r *bytes.Buffer) io.LimitedReader { + return io.LimitedReader{ + R: r, + N: int64(r.Len()), + } +} + type EOFReader struct{} func (e *EOFReader) Peek(n int) ([]byte, error) { diff --git a/pkg/common/utils/chunk.go b/pkg/common/utils/chunk.go index 6fa11d03d..92c819426 100644 --- a/pkg/common/utils/chunk.go +++ b/pkg/common/utils/chunk.go @@ -65,9 +65,9 @@ func ParseChunkSize(r network.Reader) (int, error) { return n, nil } +// SkipCRLF will only skip the next CRLF("\r\n"), otherwise, error will be returned. func SkipCRLF(reader network.Reader) error { p, err := reader.Peek(len(bytestr.StrCRLF)) - reader.Skip(len(p)) // nolint: errcheck if err != nil { return err } @@ -75,5 +75,6 @@ func SkipCRLF(reader network.Reader) error { return errBrokenChunk } + reader.Skip(len(p)) // nolint: errcheck return nil } diff --git a/pkg/common/utils/chunk_test.go b/pkg/common/utils/chunk_test.go index ffeafa163..d9c8570b0 100644 --- a/pkg/common/utils/chunk_test.go +++ b/pkg/common/utils/chunk_test.go @@ -35,6 +35,33 @@ func TestChunkParseChunkSizeGetCorrect(t *testing.T) { } } +func TestChunkParseChunkSizeGetError(t *testing.T) { + // test err from -----n, err := bytesconv.ReadHexInt(r)----- + chunkSizeBody := "" + zr := mock.NewZeroCopyReader(chunkSizeBody) + chunkSize, err := ParseChunkSize(zr) + assert.NotNil(t, err) + assert.DeepEqual(t, -1, chunkSize) + // test err from -----c, err := r.ReadByte()----- + chunkSizeBody = "0" + zr = mock.NewZeroCopyReader(chunkSizeBody) + chunkSize, err = ParseChunkSize(zr) + assert.NotNil(t, err) + assert.DeepEqual(t, -1, chunkSize) + // test err from -----c, err := r.ReadByte()----- + chunkSizeBody = "0" + "\r" + zr = mock.NewZeroCopyReader(chunkSizeBody) + chunkSize, err = ParseChunkSize(zr) + assert.NotNil(t, err) + assert.DeepEqual(t, -1, chunkSize) + // test err from -----c, err := r.ReadByte()----- + chunkSizeBody = "0" + "\r" + "\r" + zr = mock.NewZeroCopyReader(chunkSizeBody) + chunkSize, err = ParseChunkSize(zr) + assert.NotNil(t, err) + assert.DeepEqual(t, -1, chunkSize) +} + func TestChunkParseChunkSizeCorrectWhiteSpace(t *testing.T) { // test the whitespace whiteSpace := "" diff --git a/pkg/common/utils/ioutil_test.go b/pkg/common/utils/ioutil_test.go index c0d020c7e..c47718014 100644 --- a/pkg/common/utils/ioutil_test.go +++ b/pkg/common/utils/ioutil_test.go @@ -18,12 +18,81 @@ package utils import ( "bytes" + "io" "testing" + "github.com/cloudwego/hertz/pkg/common/test/mock" + "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/network" ) +type writeReadTest interface { + Write(p []byte) (n int, err error) + Malloc(n int) (buf []byte, err error) + WriteBinary(b []byte) (n int, err error) + Flush() error +} + +type readerTest interface { + ReadFrom(r io.Reader) (n int64, err error) + Malloc(n int) (buf []byte, err error) + WriteBinary(b []byte) (n int, err error) + Flush() error +} + +type testWriter struct { + w io.Writer +} + +func (t testWriter) Write(p []byte) (n int, err error) { + return +} + +func (t testWriter) Malloc(n int) (buf []byte, err error) { + return +} + +func (t testWriter) WriteBinary(b []byte) (n int, err error) { + return +} + +func (t testWriter) Flush() error { + return nil +} + +type testReader struct { + r io.ReaderFrom +} + +func (t testReader) ReadFrom(r io.Reader) (n int64, err error) { + return +} + +func (t testReader) Malloc(n int) (buf []byte, err error) { + return +} + +func (t testReader) WriteBinary(b []byte) (n int, err error) { + return +} + +func (t testReader) Flush() error { + return nil +} + +func newTestWriter(w io.Writer) writeReadTest { + return &testWriter{ + w: w, + } +} + +func newTestReaderForm(r io.ReaderFrom) readerTest { + return &testReader{ + r: r, + } +} + func TestIoutilCopyBuffer(t *testing.T) { var writeBuffer bytes.Buffer str := string("hertz is very good!!!") @@ -39,6 +108,53 @@ func TestIoutilCopyBuffer(t *testing.T) { assert.DeepEqual(t, []byte(str), writeBuffer.Bytes()) } +func TestIoutilCopyBufferWithIoWriter(t *testing.T) { + var writeBuffer bytes.Buffer + str := "hertz is very good!!!" + var buf []byte + src := bytes.NewBuffer([]byte(str)) + ioWriter := newTestWriter(&writeBuffer) + // to show example about -----w, ok := dst.(io.Writer)----- + _, ok := ioWriter.(io.Writer) + assert.DeepEqual(t, true, ok) + written, err := CopyBuffer(ioWriter, src, buf) + assert.DeepEqual(t, written, int64(0)) + assert.NotNil(t, err) + assert.DeepEqual(t, []byte(nil), writeBuffer.Bytes()) +} + +func TestIoutilCopyBufferWithIoReaderFrom(t *testing.T) { + var writeBuffer bytes.Buffer + str := "hertz is very good!!!" + var buf []byte + src := bytes.NewBufferString(str) + ioReaderFrom := newTestReaderForm(&writeBuffer) + // to show example about -----rf, ok := dst.(io.ReaderFrom)----- + _, ok := ioReaderFrom.(io.Writer) + assert.DeepEqual(t, false, ok) + _, ok = ioReaderFrom.(io.ReaderFrom) + assert.DeepEqual(t, true, ok) + written, err := CopyBuffer(ioReaderFrom, src, buf) + assert.DeepEqual(t, written, int64(0)) + assert.NotNil(t, err) + assert.DeepEqual(t, []byte(nil), writeBuffer.Bytes()) +} + +func TestIoutilCopyBufferWithPanic(t *testing.T) { + var writeBuffer bytes.Buffer + str := "hertz is very good!!!" + var buf []byte + defer func() { + if r := recover(); r != nil { + assert.DeepEqual(t, "empty buffer in io.CopyBuffer", r) + } + }() + src := bytes.NewBufferString(str) + dst := network.NewWriter(&writeBuffer) + buf = make([]byte, 0) + _, _ = CopyBuffer(dst, src, buf) +} + func TestIoutilCopyBufferWithNilBuffer(t *testing.T) { var writeBuffer bytes.Buffer str := string("hertz is very good!!!") @@ -49,7 +165,34 @@ func TestIoutilCopyBufferWithNilBuffer(t *testing.T) { written, err := CopyBuffer(dst, src, nil) assert.DeepEqual(t, written, srcLen) - assert.DeepEqual(t, err, nil) + assert.NotNil(t, err) + assert.DeepEqual(t, []byte(str), writeBuffer.Bytes()) +} + +func TestIoutilCopyBufferWithNilBufferAndIoLimitedReader(t *testing.T) { + var writeBuffer bytes.Buffer + str := "hertz is very good!!!" + src := bytes.NewBufferString(str) + reader := mock.NewLimitReader(src) + dst := network.NewWriter(&writeBuffer) + srcLen := int64(src.Len()) + written, err := CopyBuffer(dst, &reader, nil) + + assert.DeepEqual(t, written, srcLen) + assert.NotNil(t, err) + assert.DeepEqual(t, []byte(str), writeBuffer.Bytes()) + + // test l.N < 1 + writeBuffer.Reset() + str = "" + src = bytes.NewBufferString(str) + reader = mock.NewLimitReader(src) + dst = network.NewWriter(&writeBuffer) + srcLen = int64(src.Len()) + written, err = CopyBuffer(dst, &reader, nil) + + assert.DeepEqual(t, written, srcLen) + assert.NotNil(t, err) assert.DeepEqual(t, []byte(str), writeBuffer.Bytes()) } diff --git a/pkg/protocol/http1/ext/common.go b/pkg/protocol/http1/ext/common.go index 864988317..2f2280ab9 100644 --- a/pkg/protocol/http1/ext/common.go +++ b/pkg/protocol/http1/ext/common.go @@ -51,6 +51,7 @@ import ( "github.com/cloudwego/hertz/internal/bytesconv" "github.com/cloudwego/hertz/internal/bytestr" errs "github.com/cloudwego/hertz/pkg/common/errors" + "github.com/cloudwego/hertz/pkg/common/hlog" "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/network" "github.com/cloudwego/hertz/pkg/protocol" @@ -117,12 +118,17 @@ func WriteBodyChunked(w network.Writer, r io.Reader) error { if err == nil { panic("BUG: io.Reader returned 0, nil") } - if err == io.EOF { - if err = WriteChunk(w, buf[:0], true); err != nil { - break - } - err = nil + + if !errors.Is(err, io.EOF) { + hlog.SystemLogger().Warnf("writing chunked response body encountered an error from the reader, "+ + "this may cause the short of the content in response body, error: %s", err.Error()) + } + + if err = WriteChunk(w, buf[:0], true); err != nil { + break } + + err = nil break } if err = WriteChunk(w, buf[:n], true); err != nil { diff --git a/pkg/protocol/http1/ext/common_test.go b/pkg/protocol/http1/ext/common_test.go index 78a9fede0..9bc1936e3 100644 --- a/pkg/protocol/http1/ext/common_test.go +++ b/pkg/protocol/http1/ext/common_test.go @@ -24,6 +24,7 @@ import ( "testing" errs "github.com/cloudwego/hertz/pkg/common/errors" + "github.com/cloudwego/hertz/pkg/common/hlog" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/common/test/mock" "github.com/cloudwego/hertz/pkg/protocol" @@ -123,6 +124,9 @@ func TestReadRawHeaders(t *testing.T) { } func TestBodyChunked(t *testing.T) { + var log bytes.Buffer + hlog.SetOutput(&log) + body := "foobar baz aaa bbb ccc" chunk := "16\r\nfoobar baz aaa bbb ccc\r\n0\r\n" b := bytes.NewBufferString(body) @@ -137,6 +141,22 @@ func TestBodyChunked(t *testing.T) { rb, err := ReadBody(zr, -1, 0, nil) assert.Nil(t, err) assert.DeepEqual(t, body, string(rb)) + + assert.DeepEqual(t, 0, log.Len()) +} + +func TestBrokenBodyChunked(t *testing.T) { + brokenReader := mock.NewBrokenConn("") + var log bytes.Buffer + hlog.SetOutput(&log) + + var w bytes.Buffer + zw := netpoll.NewWriter(&w) + err := WriteBodyChunked(zw, brokenReader) + assert.Nil(t, err) + + assert.DeepEqual(t, []byte("0\r\n"), w.Bytes()) + assert.True(t, bytes.Contains(log.Bytes(), []byte("writing chunked response body encountered an error from the reader"))) } func TestBodyFixedSize(t *testing.T) { diff --git a/pkg/protocol/http1/req/header_test.go b/pkg/protocol/http1/req/header_test.go index 489e0aabf..d0978b3c4 100644 --- a/pkg/protocol/http1/req/header_test.go +++ b/pkg/protocol/http1/req/header_test.go @@ -421,3 +421,117 @@ func TestRequestHeaderError(t *testing.T) { err := ReadHeader(&rh, &er) assert.True(t, errors.Is(err, errs.ErrNothingRead)) } + +func TestReadHeader(t *testing.T) { + s := "P" + zr := mock.NewZeroCopyReader(s) + rh := protocol.RequestHeader{} + err := ReadHeader(&rh, zr) + assert.NotNil(t, err) +} + +func TestParseHeaders(t *testing.T) { + rh := protocol.RequestHeader{} + _, err := parseHeaders(&rh, []byte{' '}) + assert.NotNil(t, err) +} + +func TestTryRead(t *testing.T) { + rh := protocol.RequestHeader{} + s := "P" + zr := mock.NewZeroCopyReader(s) + err := tryRead(&rh, zr, 0) + assert.NotNil(t, err) +} + +func TestParseFirstLine(t *testing.T) { + tests := []struct { + input []byte + method string + uri string + protocol string + err error + }{ + // Test case 1: n < 0 + { + input: []byte("GET /path/to/resource HTTP/1.0\r\n"), + method: "GET", + uri: "/path/to/resource", + protocol: "HTTP/1.0", + err: nil, + }, + // Test case 2: n == 0 + { + input: []byte(" /path/to/resource HTTP/1.1\r\n"), + method: "", + uri: "", + protocol: "", + err: fmt.Errorf("requestURI cannot be empty in"), + }, + // Test case 3: !bytes.Equal(b[n+1:], bytestr.StrHTTP11) + { + input: []byte("POST /path/to/resource HTTP/1.2\r\n"), + method: "POST", + uri: "/path/to/resource", + protocol: "HTTP/1.0", + err: nil, + }, + } + + for _, tc := range tests { + header := &protocol.RequestHeader{} + _, err := parseFirstLine(header, tc.input) + assert.NotNil(t, err) + } +} + +func TestParse(t *testing.T) { + tests := []struct { + name string + input []byte + expected int + wantErr bool + }{ + // normal test + { + name: "normal", + input: []byte("GET /path/to/resource HTTP/1.1\r\nHost: example.com\r\n\r\n"), + expected: len([]byte("GET /path/to/resource HTTP/1.1\r\nHost: example.com\r\n\r\n")), + wantErr: false, + }, + // parseFirstLine error + { + name: "parseFirstLine error", + input: []byte("INVALID_LINE\r\nHost: example.com\r\n\r\n"), + expected: 0, + wantErr: true, + }, + // ext.ReadRawHeaders error + { + name: "ext.ReadRawHeaders error", + input: []byte("GET /path/to/resource HTTP/1.1\r\nINVALID_HEADER\r\n\r\n"), + expected: 0, + wantErr: true, + }, + // parseHeaders error + { + name: "parseHeaders error", + input: []byte("GET /path/to/resource HTTP/1.1\r\nHost: example.com\r\nINVALID_HEADER\r\n"), + expected: 0, + wantErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + header := &protocol.RequestHeader{} + bytesRead, err := parse(header, tc.input) + if (err != nil) != tc.wantErr { + t.Errorf("Expected error: %v, but got: %v", tc.wantErr, err) + } + if bytesRead != tc.expected { + t.Errorf("Expected bytes read: %d, but got: %d", tc.expected, bytesRead) + } + }) + } +} diff --git a/pkg/protocol/http1/resp/writer.go b/pkg/protocol/http1/resp/writer.go index 745ab4c5c..7b50cd434 100644 --- a/pkg/protocol/http1/resp/writer.go +++ b/pkg/protocol/http1/resp/writer.go @@ -17,6 +17,7 @@ package resp import ( + "runtime" "sync" "github.com/cloudwego/hertz/pkg/network" @@ -24,6 +25,16 @@ import ( "github.com/cloudwego/hertz/pkg/protocol/http1/ext" ) +var chunkReaderPool sync.Pool + +func init() { + chunkReaderPool = sync.Pool{ + New: func() interface{} { + return &chunkedBodyWriter{} + }, + } +} + type chunkedBodyWriter struct { sync.Once finalizeErr error @@ -59,6 +70,14 @@ func (c *chunkedBodyWriter) Flush() error { // Warning: do not call this method by yourself, unless you know what you are doing. func (c *chunkedBodyWriter) Finalize() error { c.Do(func() { + // in case no actual data from user + if !c.wroteHeader { + c.r.Header.SetContentLength(-1) + if c.finalizeErr = WriteHeader(&c.r.Header, c.w); c.finalizeErr != nil { + return + } + c.wroteHeader = true + } c.finalizeErr = ext.WriteChunk(c.w, nil, true) if c.finalizeErr != nil { return @@ -68,9 +87,19 @@ func (c *chunkedBodyWriter) Finalize() error { return c.finalizeErr } +func (c *chunkedBodyWriter) release() { + c.r = nil + c.w = nil + c.finalizeErr = nil + c.wroteHeader = false + chunkReaderPool.Put(c) +} + func NewChunkedBodyWriter(r *protocol.Response, w network.Writer) network.ExtWriter { - return &chunkedBodyWriter{ - r: r, - w: w, - } + extWriter := chunkReaderPool.Get().(*chunkedBodyWriter) + extWriter.r = r + extWriter.w = w + extWriter.Once = sync.Once{} + runtime.SetFinalizer(extWriter, (*chunkedBodyWriter).release) + return extWriter } diff --git a/pkg/protocol/http1/resp/writer_test.go b/pkg/protocol/http1/resp/writer_test.go index bca2d43a1..b57281ae3 100644 --- a/pkg/protocol/http1/resp/writer_test.go +++ b/pkg/protocol/http1/resp/writer_test.go @@ -51,3 +51,16 @@ func TestNewChunkedBodyWriter1(t *testing.T) { assert.True(t, strings.Contains(string(out), "5"+string(bytestr.StrCRLF)+"hello")) assert.True(t, strings.Contains(string(out), "0"+string(bytestr.StrCRLF)+string(bytestr.StrCRLF))) } + +func TestNewChunkedBodyWriterNoData(t *testing.T) { + response := protocol.AcquireResponse() + response.Header.Set("Foo", "Bar") + mockConn := mock.NewConn("") + w := NewChunkedBodyWriter(response, mockConn) + w.Finalize() + w.Flush() + out, _ := mockConn.WriterRecorder().ReadBinary(mockConn.WriterRecorder().WroteLen()) + assert.True(t, strings.Contains(string(out), "Transfer-Encoding: chunked")) + assert.True(t, strings.Contains(string(out), "Foo: Bar")) + assert.True(t, strings.Contains(string(out), "0"+string(bytestr.StrCRLF)+string(bytestr.StrCRLF))) +} diff --git a/pkg/protocol/multipart_test.go b/pkg/protocol/multipart_test.go index 86995d28e..6b96ae86f 100644 --- a/pkg/protocol/multipart_test.go +++ b/pkg/protocol/multipart_test.go @@ -44,6 +44,7 @@ package protocol import ( "bytes" "mime/multipart" + "net/textproto" "os" "strings" "testing" @@ -80,6 +81,15 @@ Content-Type: application/json err = WriteMultipartForm(&w, form, "") }) + // call WriteField as twice + var body bytes.Buffer + mw := multipart.NewWriter(&body) + if err = mw.WriteField("field1", "value1"); err != nil { + t.Fatal(err) + } + err = WriteMultipartForm(&w, form, s) + assert.NotNil(t, err) + // normal test err = WriteMultipartForm(&w, form, "foo") if err != nil { @@ -238,3 +248,28 @@ Content-Type: application/json _, err = MarshalMultipartForm(form, " ") assert.NotNil(t, err) } + +func TestAddFile(t *testing.T) { + t.Parallel() + bodyBuffer := &bytes.Buffer{} + w := multipart.NewWriter(bodyBuffer) + // add null file + err := AddFile(w, "test", "/test") + assert.NotNil(t, err) +} + +func TestCreateMultipartHeader(t *testing.T) { + t.Parallel() + + // filename == Null + hdr1 := make(textproto.MIMEHeader) + hdr1.Set("Content-Disposition", `form-data; name="test"`) + hdr1.Set("Content-Type", "application/json") + assert.DeepEqual(t, hdr1, CreateMultipartHeader("test", "", "application/json")) + + // normal test + hdr2 := make(textproto.MIMEHeader) + hdr2.Set("Content-Disposition", `form-data; name="test"; filename="/test.go"`) + hdr2.Set("Content-Type", "application/json") + assert.DeepEqual(t, hdr2, CreateMultipartHeader("test", "/test.go", "application/json")) +} diff --git a/version.go b/version.go index 2f7a82fb5..279360628 100644 --- a/version.go +++ b/version.go @@ -19,5 +19,5 @@ package hertz // Name and Version info of this framework, used for statistics and debug const ( Name = "Hertz" - Version = "v0.6.7" + Version = "v0.6.8" )