diff --git a/cmd/hz/app/app.go b/cmd/hz/app/app.go index bc52efec4..116df93cc 100644 --- a/cmd/hz/app/app.go +++ b/cmd/hz/app/app.go @@ -179,6 +179,7 @@ func Init() *cli.App { noRecurseFlag := cli.BoolFlag{Name: "no_recurse", Usage: "Generate master model only.", Destination: &globalArgs.NoRecurse} forceNewFlag := cli.BoolFlag{Name: "force", Aliases: []string{"f"}, Usage: "Force new a project, which will overwrite the generated files", Destination: &globalArgs.ForceNew} enableExtendsFlag := cli.BoolFlag{Name: "enable_extends", Usage: "Parse 'extends' for thrift IDL", Destination: &globalArgs.EnableExtends} + sortRouterFlag := cli.BoolFlag{Name: "sort_router", Usage: "Sort router register code, to avoid code difference", Destination: &globalArgs.SortRouter} jsonEnumStrFlag := cli.BoolFlag{Name: "json_enumstr", Usage: "Use string instead of num for json enums when idl is thrift.", Destination: &globalArgs.JSONEnumStr} queryEnumIntFlag := cli.BoolFlag{Name: "query_enumint", Usage: "Use num instead of string for query enum parameter.", Destination: &globalArgs.QueryEnumAsInt} @@ -227,6 +228,7 @@ func Init() *cli.App { &noRecurseFlag, &forceNewFlag, &enableExtendsFlag, + &sortRouterFlag, &jsonEnumStrFlag, &unsetOmitemptyFlag, @@ -261,6 +263,7 @@ func Init() *cli.App { &optPkgFlag, &noRecurseFlag, &enableExtendsFlag, + &sortRouterFlag, &jsonEnumStrFlag, &unsetOmitemptyFlag, diff --git a/cmd/hz/config/argument.go b/cmd/hz/config/argument.go index 2e208d26b..4e6d75ed5 100644 --- a/cmd/hz/config/argument.go +++ b/cmd/hz/config/argument.go @@ -72,6 +72,7 @@ type Argument struct { ForceNew bool SnakeStyleMiddleware bool EnableExtends bool + SortRouter bool CustomizeLayout string CustomizeLayoutData string diff --git a/cmd/hz/generator/handler.go b/cmd/hz/generator/handler.go index 609e1d6cf..eeab0cf7a 100644 --- a/cmd/hz/generator/handler.go +++ b/cmd/hz/generator/handler.go @@ -161,7 +161,7 @@ func (pkgGen *HttpPackageGenerator) processHandler(handler *Handler, root *Route } handler.Imports[mm.PackageName] = mm } - err := root.Update(m, handler.PackageName, singleHandlerPackage) + err := root.Update(m, handler.PackageName, singleHandlerPackage, pkgGen.SortRouter) if err != nil { return err } diff --git a/cmd/hz/generator/package.go b/cmd/hz/generator/package.go index 22e31132d..db4d9762b 100644 --- a/cmd/hz/generator/package.go +++ b/cmd/hz/generator/package.go @@ -63,12 +63,13 @@ type HttpPackageGenerator struct { IdlClientDir string // client dir for "client" command ForceClientDir string // client dir without namespace for "client" command BaseDomain string // request domain for "client" command - QueryEnumAsInt bool // client code use number for query parameter + QueryEnumAsInt bool // client code use number for query parameter ServiceGenDir string NeedModel bool HandlerByMethod bool // generate handler files with method dimension SnakeStyleMiddleware bool // use snake name style for middleware + SortRouter bool loadedBackend Backend curModel *model.Model diff --git a/cmd/hz/generator/router.go b/cmd/hz/generator/router.go index 10f431657..8a5f4b315 100644 --- a/cmd/hz/generator/router.go +++ b/cmd/hz/generator/router.go @@ -20,10 +20,13 @@ import ( "bytes" "fmt" "io/ioutil" + "math" "path/filepath" "regexp" "sort" + "strconv" "strings" + "unicode" "github.com/cloudwego/hertz/cmd/hz/util" ) @@ -73,7 +76,7 @@ func (routerNode *RouterNode) Sort() { sort.Sort(routerNode.Children) } -func (routerNode *RouterNode) Update(method *HttpMethod, handlerType, handlerPkg string) error { +func (routerNode *RouterNode) Update(method *HttpMethod, handlerType, handlerPkg string, sortRouter bool) error { if method.Path == "" { return fmt.Errorf("empty path for method '%s'", method.Name) } @@ -81,12 +84,12 @@ func (routerNode *RouterNode) Update(method *HttpMethod, handlerType, handlerPkg if paths[0] == "" { paths = paths[1:] } - parent, last := routerNode.FindNearest(paths) + parent, last := routerNode.FindNearest(paths, method.HTTPMethod) if last == len(paths) { return fmt.Errorf("path '%s' has been registered", method.Path) } name := util.ToVarName(paths[:last]) - parent.Insert(name, method, handlerType, paths[last:], handlerPkg) + parent.Insert(name, method, handlerType, paths[last:], handlerPkg, sortRouter) parent.Sort() return nil } @@ -192,7 +195,7 @@ func (routerNode *RouterNode) DFS(i int, hook func(layer int, node *RouterNode) var handlerPkgMap map[string]string -func (routerNode *RouterNode) Insert(name string, method *HttpMethod, handlerType string, paths []string, handlerPkg string) { +func (routerNode *RouterNode) Insert(name string, method *HttpMethod, handlerType string, paths []string, handlerPkg string, sortRouter bool) { cur := routerNode for i, p := range paths { c := &RouterNode{ @@ -229,6 +232,9 @@ func (routerNode *RouterNode) Insert(name string, method *HttpMethod, handlerTyp cur.Children = make([]*RouterNode, 0, 1) } cur.Children = append(cur.Children, c) + if sortRouter { + sort.Sort(cur.Children) + } cur = c } } @@ -240,14 +246,18 @@ func getHttpMethod(method string) string { return strings.ToUpper(method) } -func (routerNode *RouterNode) FindNearest(paths []string) (*RouterNode, int) { +func (routerNode *RouterNode) FindNearest(paths []string, method string) (*RouterNode, int) { ns := len(paths) cur := routerNode i := 0 path := paths[i] for j := 0; j < len(cur.Children); j++ { c := cur.Children[j] - if ("/" + path) == c.Path { + tmpMethod := "" + if i == ns { // group do not have http method + tmpMethod = method + } + if ("/"+path) == c.Path && strings.EqualFold(c.HttpMethod, tmpMethod) { i++ if i == ns { return cur, i - 1 @@ -270,17 +280,35 @@ func (c childrenRouterInfo) Len() int { // Less reports whether the element with // index i should sort before the element with index j. func (c childrenRouterInfo) Less(i, j int) bool { - ci := c[i].Path - if len(c[i].Children) != 0 { - ci = ci[1:] + if c[i].HttpMethod == "" && c[j].HttpMethod != "" { + return false } - cj := c[j].Path - if len(c[j].Children) != 0 { - cj = cj[1:] + if c[i].HttpMethod != "" && c[j].HttpMethod == "" { + return true } + // remove non-litter char + // eg. /a -> a + // /:a -> a + ci := removeNonLetterPrefix(c[i].Path) + cj := removeNonLetterPrefix(c[j].Path) + + // if ci == cj, use HTTP mothod for sort, preventing sorting inconsistencies + if ci == cj { + return c[i].HttpMethod < c[j].HttpMethod + } + return ci < cj } +func removeNonLetterPrefix(str string) string { + for i, char := range str { + if unicode.IsLetter(char) || unicode.IsDigit(char) { + return str[i:] + } + } + return str +} + // Swap swaps the elements with indexes i and j. func (c childrenRouterInfo) Swap(i, j int) { c[i], c[j] = c[j], c[i] @@ -341,6 +369,29 @@ func (pkgGen *HttpPackageGenerator) updateRegister(pkg, rDir, pkgName string) er return nil } +func appendMw(mws []string, mw string) ([]string, string) { + for i := 0; true; i++ { + if i == math.MaxInt { + break + } + if !stringsIncludes(mws, mw) { + mws = append(mws, mw) + break + } + mw += strconv.Itoa(i) + } + return mws, mw +} + +func stringsIncludes(strs []string, str string) bool { + for _, s := range strs { + if s == str { + return true + } + } + return false +} + func (pkgGen *HttpPackageGenerator) genRouter(pkg *HttpPackage, root *RouterNode, handlerPackage, routerDir, routerPackage string) error { err := root.DyeGroupName(pkgGen.SnakeStyleMiddleware) if err != nil { @@ -367,6 +418,31 @@ func (pkgGen *HttpPackageGenerator) genRouter(pkg *HttpPackage, root *RouterNode router.HandlerPackages = handlerMap } + if pkgGen.SnakeStyleMiddleware { // unique middleware name for SnakeStyleMiddleware + mws := []string{} + hook := func(layer int, node *RouterNode) error { + if len(node.Children) == 0 { + return nil + } + groupMwName := node.GroupMiddleware + handlerMwName := node.HandlerMiddleware + if len(groupMwName) != 0 { + mws, groupMwName = appendMw(mws, groupMwName) + } + if len(handlerMwName) != 0 { + mws, handlerMwName = appendMw(mws, handlerMwName) + } + if groupMwName != node.GroupMiddleware { + node.GroupMiddleware = groupMwName + } + if handlerMwName != node.HandlerMiddleware { + node.HandlerMiddleware = handlerMwName + } + return nil + } + root.DFS(0, hook) + } + // store router info pkg.RouterInfo = &router diff --git a/cmd/hz/protobuf/plugin.go b/cmd/hz/protobuf/plugin.go index 95d2473e1..d6f775cef 100644 --- a/cmd/hz/protobuf/plugin.go +++ b/cmd/hz/protobuf/plugin.go @@ -621,6 +621,7 @@ func (plugin *Plugin) genHttpPackage(ast *descriptorpb.FileDescriptorProto, deps BaseDomain: args.BaseDomain, QueryEnumAsInt: args.QueryEnumAsInt, SnakeStyleMiddleware: args.SnakeStyleMiddleware, + SortRouter: args.SortRouter, } if args.ModelBackend != "" { diff --git a/cmd/hz/thrift/plugin.go b/cmd/hz/thrift/plugin.go index 956a0a6ec..a9a628dc9 100644 --- a/cmd/hz/thrift/plugin.go +++ b/cmd/hz/thrift/plugin.go @@ -152,6 +152,7 @@ func (plugin *Plugin) Run() int { BaseDomain: args.BaseDomain, QueryEnumAsInt: args.QueryEnumAsInt, SnakeStyleMiddleware: args.SnakeStyleMiddleware, + SortRouter: args.SortRouter, } if args.ModelBackend != "" { sg.Backend = meta.Backend(args.ModelBackend) @@ -232,6 +233,10 @@ func (plugin *Plugin) handleRequest() error { thriftgoUtil = golang.NewCodeUtils(backend.DummyLogFunc()) thriftgoUtil.HandleOptions(req.GeneratorParameters) + // TEST: For plugin debug. Delete below codes after release! + buf, _ := thriftgo_plugin.MarshalRequest(req) + err = os.MkdirAll("/Users/bytedance/go/src/hertz_master/cmd/hz/thrift/test_data", os.FileMode(0755)) + err = os.WriteFile("/Users/bytedance/go/src/hertz_master/cmd/hz/thrift/test_data/request_thrift.out", buf, os.FileMode(0755)) return nil } diff --git a/cmd/hz/thrift/plugin_test.go b/cmd/hz/thrift/plugin_test.go index 5e471b67e..13d44b48e 100644 --- a/cmd/hz/thrift/plugin_test.go +++ b/cmd/hz/thrift/plugin_test.go @@ -17,6 +17,8 @@ package thrift import ( + "github.com/cloudwego/thriftgo/generator/backend" + "github.com/cloudwego/thriftgo/generator/golang" "io/ioutil" "testing" @@ -26,7 +28,7 @@ import ( ) func TestRun(t *testing.T) { - data, err := ioutil.ReadFile("../testdata/request_thrift.out") + data, err := ioutil.ReadFile("/Users/bytedance/go/src/hertz_master/cmd/hz/thrift/test_data/request_thrift.out") if err != nil { t.Fatal(err) } @@ -35,6 +37,8 @@ func TestRun(t *testing.T) { if err != nil { t.Fatal(err) } + thriftgoUtil = golang.NewCodeUtils(backend.DummyLogFunc()) + thriftgoUtil.HandleOptions(req.GeneratorParameters) plu := new(Plugin) plu.setLogger() @@ -79,12 +83,21 @@ func TestRun(t *testing.T) { HandlerDir: handlerDir, RouterDir: routerDir, ModelDir: modelDir, + UseDir: args.Use, ClientDir: clientDir, TemplateGenerator: generator.TemplateGenerator{ OutputDir: args.OutDir, + Excludes: args.Excludes, }, - ProjPackage: pkg, - Options: options, + ProjPackage: pkg, + Options: options, + HandlerByMethod: args.HandlerByMethod, + CmdType: args.CmdType, + ForceClientDir: args.ForceClientDir, + BaseDomain: args.BaseDomain, + QueryEnumAsInt: args.QueryEnumAsInt, + SnakeStyleMiddleware: args.SnakeStyleMiddleware, + SortRouter: args.SortRouter, } if args.ModelBackend != "" { sg.Backend = meta.Backend(args.ModelBackend) diff --git a/cmd/hz/thrift/test_data/request_thrift.out b/cmd/hz/thrift/test_data/request_thrift.out new file mode 100755 index 000000000..cfc827dfa Binary files /dev/null and b/cmd/hz/thrift/test_data/request_thrift.out differ