Skip to content

Commit

Permalink
Use regex, Luke
Browse files Browse the repository at this point in the history
  • Loading branch information
alexshtin committed Sep 24, 2024
1 parent 4336b0d commit 1447638
Showing 1 changed file with 10 additions and 46 deletions.
56 changes: 10 additions & 46 deletions cmd/tools/genrpcserverinterceptors/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ import (
"os"
"path"
"reflect"
"sort"
"strconv"
"regexp"
"strings"
"text/template"

Expand Down Expand Up @@ -71,18 +70,8 @@ var (
reflect.TypeOf((*matchingservice.MatchingServiceServer)(nil)).Elem(),
}

// Only request fields that end with suffixes or match exact name are eligible for deeper inspection.
fieldNameSuffixes = []string{
"Request",
}
fieldNames = map[string]struct{}{
"Completion": {},
"UpdateRef": {},
"ParentExecution": {},
"WorkflowState": {},
"ExecutionInfo": {},
"ExecutionState": {},
}
// Only request fields that match the pattern are eligible for deeper inspection.
fieldNameRegex = regexp.MustCompile("^(?:.*Request|Completion|UpdateRef|ParentExecution|WorkflowState|ExecutionInfo|ExecutionState)$")

// These types have task_token field, but it is not of type *tokenspb.Task and doesn't have Workflow tags.
excludeTaskTokenTypes = []reflect.Type{
Expand Down Expand Up @@ -215,26 +204,23 @@ func workflowTagGetters(requestT reflect.Type, depth int) requestData {
}
}

for _, nestedRequest := range subFields(requestT) {
for fieldNum := 0; fieldNum < requestT.Elem().NumField(); fieldNum++ {
// Iterates over fields in order they defined in proto file, not proto index.
// Order is important because the first match wins.
nestedRequest := requestT.Elem().Field(fieldNum)

if nestedRequest.Type.Kind() != reflect.Ptr {
continue
}
if nestedRequest.Type.Elem().Kind() != reflect.Struct {
continue
}
hasAllowedSuffix := false
for _, suffix := range fieldNameSuffixes {
if strings.HasSuffix(nestedRequest.Name, suffix) {
hasAllowedSuffix = true
break
}
}
if _, hasAllowedName := fieldNames[nestedRequest.Name]; !hasAllowedName && !hasAllowedSuffix {
if !fieldNameRegex.MatchString(nestedRequest.Name) {
continue
}

nestedRd := workflowTagGetters(nestedRequest.Type, depth+1)
// First match wins. If getter is already set, it won't be overwritten.
// First match wins: if getter is already set, it won't be overwritten.
if rd.WorkflowIdGetter == "" && nestedRd.WorkflowIdGetter != "" {
rd.WorkflowIdGetter = fmt.Sprintf("Get%s().%s", nestedRequest.Name, nestedRd.WorkflowIdGetter)
}
Expand All @@ -248,28 +234,6 @@ func workflowTagGetters(requestT reflect.Type, depth int) requestData {
return rd
}

func subFields(t reflect.Type) []reflect.StructField {
// This function returns subfields ordered in proto index order. Tag is the string like:
// `protobuf:"bytes,2,opt,name=workflow_id,json=workflowId,proto3" json:"workflow_id,omitempty"`
// ^ - this number is used for ordering.

var fields []reflect.StructField
for fieldNum := 0; fieldNum < t.Elem().NumField(); fieldNum++ {
f := t.Elem().Field(fieldNum)
if _, ok := f.Tag.Lookup("protobuf"); ok {
fields = append(fields, f)
}
}
protoOrder := func(tag reflect.StructTag) int {
o, _ := strconv.Atoi(strings.Split(tag.Get("protobuf"), ",")[1])
return o
}
sort.Slice(fields, func(i, j int) bool {
return protoOrder(fields[i].Tag) < protoOrder(fields[j].Tag)
})
return fields
}

func callWithFile(generator func(io.Writer, reflect.Type), server reflect.Type, outPath string, licenseText string) {
filename := path.Join(outPath, camelCaseToSnakeCase(server.Name())+"_gen.go")
w, err := os.Create(filename)
Expand Down

0 comments on commit 1447638

Please sign in to comment.