Skip to content

Commit

Permalink
optimize: router_sort
Browse files Browse the repository at this point in the history
  • Loading branch information
FGYFFFF committed Apr 24, 2024
1 parent bd2a9b2 commit 7c2841f
Show file tree
Hide file tree
Showing 9 changed files with 117 additions and 17 deletions.
3 changes: 3 additions & 0 deletions cmd/hz/app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ func Init() *cli.App {
noRecurseFlag := cli.BoolFlag{Name: "no_recurse", Usage: "Generate master model only.", Destination: &globalArgs.NoRecurse}
forceNewFlag := cli.BoolFlag{Name: "force", Aliases: []string{"f"}, Usage: "Force new a project, which will overwrite the generated files", Destination: &globalArgs.ForceNew}
enableExtendsFlag := cli.BoolFlag{Name: "enable_extends", Usage: "Parse 'extends' for thrift IDL", Destination: &globalArgs.EnableExtends}
sortRouterFlag := cli.BoolFlag{Name: "sort_router", Usage: "Sort router register code, to avoid code difference", Destination: &globalArgs.SortRouter}

jsonEnumStrFlag := cli.BoolFlag{Name: "json_enumstr", Usage: "Use string instead of num for json enums when idl is thrift.", Destination: &globalArgs.JSONEnumStr}
queryEnumIntFlag := cli.BoolFlag{Name: "query_enumint", Usage: "Use num instead of string for query enum parameter.", Destination: &globalArgs.QueryEnumAsInt}
Expand Down Expand Up @@ -227,6 +228,7 @@ func Init() *cli.App {
&noRecurseFlag,
&forceNewFlag,
&enableExtendsFlag,
&sortRouterFlag,

&jsonEnumStrFlag,
&unsetOmitemptyFlag,
Expand Down Expand Up @@ -261,6 +263,7 @@ func Init() *cli.App {
&optPkgFlag,
&noRecurseFlag,
&enableExtendsFlag,
&sortRouterFlag,

&jsonEnumStrFlag,
&unsetOmitemptyFlag,
Expand Down
1 change: 1 addition & 0 deletions cmd/hz/config/argument.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ type Argument struct {
ForceNew bool
SnakeStyleMiddleware bool
EnableExtends bool
SortRouter bool

CustomizeLayout string
CustomizeLayoutData string
Expand Down
2 changes: 1 addition & 1 deletion cmd/hz/generator/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ func (pkgGen *HttpPackageGenerator) processHandler(handler *Handler, root *Route
}
handler.Imports[mm.PackageName] = mm
}
err := root.Update(m, handler.PackageName, singleHandlerPackage)
err := root.Update(m, handler.PackageName, singleHandlerPackage, pkgGen.SortRouter)
if err != nil {
return err
}
Expand Down
3 changes: 2 additions & 1 deletion cmd/hz/generator/package.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,13 @@ type HttpPackageGenerator struct {
IdlClientDir string // client dir for "client" command
ForceClientDir string // client dir without namespace for "client" command
BaseDomain string // request domain for "client" command
QueryEnumAsInt bool // client code use number for query parameter
QueryEnumAsInt bool // client code use number for query parameter
ServiceGenDir string

NeedModel bool
HandlerByMethod bool // generate handler files with method dimension
SnakeStyleMiddleware bool // use snake name style for middleware
SortRouter bool

loadedBackend Backend
curModel *model.Model
Expand Down
100 changes: 88 additions & 12 deletions cmd/hz/generator/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,13 @@ import (
"bytes"
"fmt"
"io/ioutil"
"math"
"path/filepath"
"regexp"
"sort"
"strconv"
"strings"
"unicode"

"github.com/cloudwego/hertz/cmd/hz/util"
)
Expand Down Expand Up @@ -73,20 +76,20 @@ func (routerNode *RouterNode) Sort() {
sort.Sort(routerNode.Children)
}

func (routerNode *RouterNode) Update(method *HttpMethod, handlerType, handlerPkg string) error {
func (routerNode *RouterNode) Update(method *HttpMethod, handlerType, handlerPkg string, sortRouter bool) error {
if method.Path == "" {
return fmt.Errorf("empty path for method '%s'", method.Name)
}
paths := strings.Split(method.Path, "/")
if paths[0] == "" {
paths = paths[1:]
}
parent, last := routerNode.FindNearest(paths)
parent, last := routerNode.FindNearest(paths, method.HTTPMethod)
if last == len(paths) {
return fmt.Errorf("path '%s' has been registered", method.Path)
}
name := util.ToVarName(paths[:last])
parent.Insert(name, method, handlerType, paths[last:], handlerPkg)
parent.Insert(name, method, handlerType, paths[last:], handlerPkg, sortRouter)
parent.Sort()
return nil
}
Expand Down Expand Up @@ -192,7 +195,7 @@ func (routerNode *RouterNode) DFS(i int, hook func(layer int, node *RouterNode)

var handlerPkgMap map[string]string

func (routerNode *RouterNode) Insert(name string, method *HttpMethod, handlerType string, paths []string, handlerPkg string) {
func (routerNode *RouterNode) Insert(name string, method *HttpMethod, handlerType string, paths []string, handlerPkg string, sortRouter bool) {
cur := routerNode
for i, p := range paths {
c := &RouterNode{
Expand Down Expand Up @@ -229,6 +232,9 @@ func (routerNode *RouterNode) Insert(name string, method *HttpMethod, handlerTyp
cur.Children = make([]*RouterNode, 0, 1)
}
cur.Children = append(cur.Children, c)
if sortRouter {
sort.Sort(cur.Children)
}
cur = c
}
}
Expand All @@ -240,14 +246,18 @@ func getHttpMethod(method string) string {
return strings.ToUpper(method)
}

func (routerNode *RouterNode) FindNearest(paths []string) (*RouterNode, int) {
func (routerNode *RouterNode) FindNearest(paths []string, method string) (*RouterNode, int) {
ns := len(paths)
cur := routerNode
i := 0
path := paths[i]
for j := 0; j < len(cur.Children); j++ {
c := cur.Children[j]
if ("/" + path) == c.Path {
tmpMethod := ""
if i == ns { // group do not have http method
tmpMethod = method
}
if ("/"+path) == c.Path && strings.EqualFold(c.HttpMethod, tmpMethod) {
i++
if i == ns {
return cur, i - 1
Expand All @@ -270,17 +280,35 @@ func (c childrenRouterInfo) Len() int {
// Less reports whether the element with
// index i should sort before the element with index j.
func (c childrenRouterInfo) Less(i, j int) bool {
ci := c[i].Path
if len(c[i].Children) != 0 {
ci = ci[1:]
if c[i].HttpMethod == "" && c[j].HttpMethod != "" {
return false
}
cj := c[j].Path
if len(c[j].Children) != 0 {
cj = cj[1:]
if c[i].HttpMethod != "" && c[j].HttpMethod == "" {
return true
}
// remove non-litter char
// eg. /a -> a
// /:a -> a
ci := removeNonLetterPrefix(c[i].Path)
cj := removeNonLetterPrefix(c[j].Path)

// if ci == cj, use HTTP mothod for sort, preventing sorting inconsistencies
if ci == cj {
return c[i].HttpMethod < c[j].HttpMethod
}

return ci < cj
}

func removeNonLetterPrefix(str string) string {
for i, char := range str {
if unicode.IsLetter(char) || unicode.IsDigit(char) {
return str[i:]
}
}
return str
}

// Swap swaps the elements with indexes i and j.
func (c childrenRouterInfo) Swap(i, j int) {
c[i], c[j] = c[j], c[i]
Expand Down Expand Up @@ -341,6 +369,29 @@ func (pkgGen *HttpPackageGenerator) updateRegister(pkg, rDir, pkgName string) er
return nil
}

func appendMw(mws []string, mw string) ([]string, string) {
for i := 0; true; i++ {
if i == math.MaxInt {
break
}
if !stringsIncludes(mws, mw) {
mws = append(mws, mw)
break
}
mw += strconv.Itoa(i)
}
return mws, mw
}

func stringsIncludes(strs []string, str string) bool {
for _, s := range strs {
if s == str {
return true
}
}
return false
}

func (pkgGen *HttpPackageGenerator) genRouter(pkg *HttpPackage, root *RouterNode, handlerPackage, routerDir, routerPackage string) error {
err := root.DyeGroupName(pkgGen.SnakeStyleMiddleware)
if err != nil {
Expand All @@ -367,6 +418,31 @@ func (pkgGen *HttpPackageGenerator) genRouter(pkg *HttpPackage, root *RouterNode
router.HandlerPackages = handlerMap
}

if pkgGen.SnakeStyleMiddleware { // unique middleware name for SnakeStyleMiddleware
mws := []string{}
hook := func(layer int, node *RouterNode) error {
if len(node.Children) == 0 {
return nil
}
groupMwName := node.GroupMiddleware
handlerMwName := node.HandlerMiddleware
if len(groupMwName) != 0 {
mws, groupMwName = appendMw(mws, groupMwName)
}
if len(handlerMwName) != 0 {
mws, handlerMwName = appendMw(mws, handlerMwName)
}
if groupMwName != node.GroupMiddleware {
node.GroupMiddleware = groupMwName
}
if handlerMwName != node.HandlerMiddleware {
node.HandlerMiddleware = handlerMwName
}
return nil
}
root.DFS(0, hook)
}

// store router info
pkg.RouterInfo = &router

Expand Down
1 change: 1 addition & 0 deletions cmd/hz/protobuf/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,7 @@ func (plugin *Plugin) genHttpPackage(ast *descriptorpb.FileDescriptorProto, deps
BaseDomain: args.BaseDomain,
QueryEnumAsInt: args.QueryEnumAsInt,
SnakeStyleMiddleware: args.SnakeStyleMiddleware,
SortRouter: args.SortRouter,
}

if args.ModelBackend != "" {
Expand Down
5 changes: 5 additions & 0 deletions cmd/hz/thrift/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ func (plugin *Plugin) Run() int {
BaseDomain: args.BaseDomain,
QueryEnumAsInt: args.QueryEnumAsInt,
SnakeStyleMiddleware: args.SnakeStyleMiddleware,
SortRouter: args.SortRouter,
}
if args.ModelBackend != "" {
sg.Backend = meta.Backend(args.ModelBackend)
Expand Down Expand Up @@ -232,6 +233,10 @@ func (plugin *Plugin) handleRequest() error {
thriftgoUtil = golang.NewCodeUtils(backend.DummyLogFunc())
thriftgoUtil.HandleOptions(req.GeneratorParameters)

// TEST: For plugin debug. Delete below codes after release!
buf, _ := thriftgo_plugin.MarshalRequest(req)
err = os.MkdirAll("/Users/bytedance/go/src/hertz_master/cmd/hz/thrift/test_data", os.FileMode(0755))
err = os.WriteFile("/Users/bytedance/go/src/hertz_master/cmd/hz/thrift/test_data/request_thrift.out", buf, os.FileMode(0755))
return nil
}

Expand Down
19 changes: 16 additions & 3 deletions cmd/hz/thrift/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package thrift

import (
"github.com/cloudwego/thriftgo/generator/backend"
"github.com/cloudwego/thriftgo/generator/golang"
"io/ioutil"
"testing"

Expand All @@ -26,7 +28,7 @@ import (
)

func TestRun(t *testing.T) {
data, err := ioutil.ReadFile("../testdata/request_thrift.out")
data, err := ioutil.ReadFile("/Users/bytedance/go/src/hertz_master/cmd/hz/thrift/test_data/request_thrift.out")
if err != nil {
t.Fatal(err)
}
Expand All @@ -35,6 +37,8 @@ func TestRun(t *testing.T) {
if err != nil {
t.Fatal(err)
}
thriftgoUtil = golang.NewCodeUtils(backend.DummyLogFunc())
thriftgoUtil.HandleOptions(req.GeneratorParameters)

plu := new(Plugin)
plu.setLogger()
Expand Down Expand Up @@ -79,12 +83,21 @@ func TestRun(t *testing.T) {
HandlerDir: handlerDir,
RouterDir: routerDir,
ModelDir: modelDir,
UseDir: args.Use,
ClientDir: clientDir,
TemplateGenerator: generator.TemplateGenerator{
OutputDir: args.OutDir,
Excludes: args.Excludes,
},
ProjPackage: pkg,
Options: options,
ProjPackage: pkg,
Options: options,
HandlerByMethod: args.HandlerByMethod,
CmdType: args.CmdType,
ForceClientDir: args.ForceClientDir,
BaseDomain: args.BaseDomain,
QueryEnumAsInt: args.QueryEnumAsInt,
SnakeStyleMiddleware: args.SnakeStyleMiddleware,
SortRouter: args.SortRouter,
}
if args.ModelBackend != "" {
sg.Backend = meta.Backend(args.ModelBackend)
Expand Down
Binary file added cmd/hz/thrift/test_data/request_thrift.out
Binary file not shown.

0 comments on commit 7c2841f

Please sign in to comment.