-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfuncdispatcher.go
156 lines (133 loc) · 4.54 KB
/
funcdispatcher.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
package command
import (
"context"
"fmt"
"reflect"
)
// type getReplacementValFunc func() reflect.Value
// var (
// ReplaceArgTypes = map[reflect.Type]getReplacementValFunc{
// typeOfContext: func() reflect.Value { return reflect.ValueOf(context.TODO()) },
// }
// )
// type argReplacement struct {
// argIndex int // index of the argument
// insert bool // if the argument should be inserted or replaced
// getReplacementVal getReplacementValFunc
// }
// functionArgTypesWithoutReplaceables splits the parameters of funcType into
// the ones that must be supplied from a command.Args struct and the ones the
// dispatcher fills in by itself.
//
// A context.Context in the first parameter position is left out of argTypes;
// firstArgIsContext reports its presence instead. Every parameter of kind
// reflect.Func is also left out of argTypes and recorded in insertArgs together
// with its absolute parameter index and the zero (nil) value of its type.
// All remaining parameter types are returned in declaration order.
func functionArgTypesWithoutReplaceables(funcType reflect.Type) (argTypes []reflect.Type, firstArgIsContext bool, insertArgs []insertArg) {
	n := funcType.NumIn()
	argTypes = make([]reflect.Type, 0, n)
	for idx := 0; idx < n; idx++ {
		paramType := funcType.In(idx)
		switch {
		case idx == 0 && paramType == typeOfContext:
			// Supplied from the dispatch context at call time.
			firstArgIsContext = true
		case paramType.Kind() == reflect.Func:
			// Callback parameters have no Args field; a nil func of the
			// matching type is inserted at this index when calling.
			insertArgs = append(insertArgs, insertArg{index: idx, value: reflect.Zero(paramType)})
		default:
			argTypes = append(argTypes, paramType)
		}
	}
	return argTypes, firstArgIsContext, insertArgs
}
// insertArg describes a function parameter that is not backed by a field of
// the command.Args struct and is instead inserted by the dispatcher at call time.
type insertArg struct {
	// index is the absolute parameter position in the function signature
	// (counting a leading context.Context parameter, if any).
	index int
	// value is the argument passed at that position.
	value reflect.Value
}
// funcDispatcher calls a command function via reflection, using argument
// values that line up with the fields of a command.Args struct definition.
// Instances are created by newFuncDispatcher.
type funcDispatcher struct {
	// argsDef describes the command.Args struct whose fields correspond, in
	// order, to the function parameters that are not context or callbacks.
	argsDef *ArgsDef
	// funcVal is the command function that gets called.
	funcVal reflect.Value
	// funcType is the type of funcVal; newFuncDispatcher ensures it is of
	// kind reflect.Func.
	funcType reflect.Type
	// argReplacements []argReplacement (argument replacement is not implemented)
	// firstArgIsContext is true if the first parameter is a context.Context,
	// which is filled with the context passed to the call methods.
	firstArgIsContext bool
	// insertArgs are parameters without an Args field (callback functions),
	// inserted at their absolute index, in ascending index order, at call time.
	insertArgs []insertArg
	// errorIndex is the index of the trailing error result that is split off
	// from the other results, or -1 if the last result is not an error.
	errorIndex int
}
// newFuncDispatcher creates a funcDispatcher for commandFunc.
//
// The parameters of commandFunc, excluding a leading context.Context and any
// parameters of function type, must match the fields of argsDef one-to-one in
// number, order, and type. A leading context.Context is filled from the
// dispatch context and function-typed parameters are passed as nil functions.
// If the last result of commandFunc is of type error, it is split off from the
// other results when the function is called.
//
// An error is returned if argsDef is nil, if commandFunc is nil or not a
// function, or if its parameters do not match argsDef.
func newFuncDispatcher(argsDef *ArgsDef, commandFunc interface{}) (disp *funcDispatcher, err error) {
	if argsDef == nil {
		return nil, fmt.Errorf("command.Args definition must not be nil")
	}
	funcVal := reflect.ValueOf(commandFunc)
	if !funcVal.IsValid() {
		// reflect.ValueOf(nil) carries no type; calling Type on it would panic.
		return nil, fmt.Errorf("expected a function or method, but got nil")
	}
	funcType := funcVal.Type()
	if funcType.Kind() != reflect.Func {
		return nil, fmt.Errorf("expected a function or method, but got %s", funcType)
	}
	if funcVal.IsNil() {
		// A typed nil func passes the kind check but would panic when called.
		return nil, fmt.Errorf("expected a function or method, but got nil %s", funcType)
	}

	disp = &funcDispatcher{
		argsDef:    argsDef,
		funcVal:    funcVal,
		funcType:   funcType,
		errorIndex: -1,
	}
	if numResults := funcType.NumOut(); numResults > 0 && funcType.Out(numResults-1) == typeOfError {
		disp.errorIndex = numResults - 1
	}

	var funcArgTypes []reflect.Type
	funcArgTypes, disp.firstArgIsContext, disp.insertArgs = functionArgTypesWithoutReplaceables(funcType)

	fields := argsDef.argStructFields
	if len(fields) != len(funcArgTypes) {
		return nil, fmt.Errorf("number of fields in command.Args struct (%d) does not match number of function arguments (%d)", len(fields), len(funcArgTypes))
	}
	for i := range fields {
		if fields[i].Field.Type != funcArgTypes[i] {
			return nil, fmt.Errorf(
				"type of command.Args struct field '%s' is %s, which does not match function argument %d type %s",
				fields[i].Field.Name,
				fields[i].Field.Type,
				i,
				funcArgTypes[i],
			)
		}
	}
	return disp, nil
}
// callWithResultsHandlers calls the command function and passes the outcome to
// each of resultsHandlers in order.
//
// argVals must line up with the fields of disp.argsDef; the context and any
// callback parameters are added automatically. argVals itself is never
// modified. If the function returns an error as its last result, that error is
// passed to every handler and returned at the end. A handler that returns an
// error other than the function's own error stops the handler chain, and that
// handler error is returned instead.
func (disp *funcDispatcher) callWithResultsHandlers(ctx context.Context, argVals []reflect.Value, resultsHandlers []ResultsHandler) error {
	callArgs := disp.buildCallArgs(ctx, argVals)
	resultVals, resultErr := disp.invoke(callArgs)
	for _, resultsHandler := range resultsHandlers {
		// NOTE(review): handlers receive the complete call parameter list
		// (including ctx and inserted callbacks when present), which is then
		// not index-aligned with disp.argsDef — confirm this is intended.
		err := resultsHandler.HandleResults(disp.argsDef, callArgs, resultVals, resultErr)
		if err != nil && err != resultErr {
			return err
		}
	}
	return resultErr
}

// callAndReturnResults calls the command function and returns its results.
//
// argVals must line up with the fields of disp.argsDef; the context and any
// callback parameters are added automatically. argVals itself is never
// modified. If the function returns an error as its last result, it is
// removed from the returned values and returned as the error.
func (disp *funcDispatcher) callAndReturnResults(ctx context.Context, argVals []reflect.Value) ([]reflect.Value, error) {
	return disp.invoke(disp.buildCallArgs(ctx, argVals))
}

// buildCallArgs returns the complete parameter list for disp.funcVal: ctx
// first (if the function takes a context), then argVals with every
// disp.insertArgs value spliced in at its absolute index. The caller's argVals
// backing array is never written to; when nothing needs to be added, argVals
// is returned unchanged.
func (disp *funcDispatcher) buildCallArgs(ctx context.Context, argVals []reflect.Value) []reflect.Value {
	if !disp.firstArgIsContext && len(disp.insertArgs) == 0 {
		return argVals
	}
	size := len(argVals) + len(disp.insertArgs)
	if disp.firstArgIsContext {
		size++
	}
	// Build into a fresh slice: appending into argVals[:i] could overwrite
	// the caller's values when argVals has spare capacity.
	callArgs := make([]reflect.Value, 0, size)
	if disp.firstArgIsContext {
		callArgs = append(callArgs, reflect.ValueOf(ctx))
	}
	callArgs = append(callArgs, argVals...)
	// insertArgs are in ascending index order, so each insert lands at its
	// final position after the earlier ones have shifted the tail right.
	for _, insert := range disp.insertArgs {
		callArgs = append(callArgs, reflect.Value{})
		copy(callArgs[insert.index+1:], callArgs[insert.index:])
		callArgs[insert.index] = insert.value
	}
	return callArgs
}

// invoke calls disp.funcVal with the complete parameter list and splits off
// the trailing error result if the function has one.
func (disp *funcDispatcher) invoke(callArgs []reflect.Value) (resultVals []reflect.Value, resultErr error) {
	if disp.funcType.IsVariadic() {
		// The variadic parameter's value is already a slice (its Args field
		// has the parameter's slice type), so pass it through unexpanded.
		resultVals = disp.funcVal.CallSlice(callArgs)
	} else {
		resultVals = disp.funcVal.Call(callArgs)
	}
	if disp.errorIndex != -1 {
		// The comma-ok form turns a nil error result into a nil error
		// instead of panicking on the failed type assertion.
		resultErr, _ = resultVals[disp.errorIndex].Interface().(error)
		resultVals = resultVals[:disp.errorIndex]
	}
	return resultVals, resultErr
}