diff --git a/chainable.go b/chainable.go index 19bccb6..66e24b2 100644 --- a/chainable.go +++ b/chainable.go @@ -97,7 +97,7 @@ func (lk *link) process(linkIndex int, args []Argument) ([]Argument, error) { // call the function out := []Argument{} - for _, o := range vfn.Call(reflectArgs(args)) { + for _, o := range vfn.Call(reflectArgs(vfnType, args)) { out = append(out, o.Interface()) } @@ -147,11 +147,17 @@ func validateArgs(linkIndex int, fn reflect.Type, args []Argument) error { // reflectArgs transforms the args list in a list of // reflect.Value, used to call a function using reflection -func reflectArgs(args []Argument) []reflect.Value { +func reflectArgs(fnType reflect.Type, args []Argument) []reflect.Value { in := make([]reflect.Value, len(args)) for k, arg := range args { - in[k] = reflect.ValueOf(arg) + if arg == nil { + // Use the zero value of the function parameter type, + // since "reflect.Call" doesn't accept "nil" parameters + in[k] = reflect.New(fnType.In(k)).Elem() + } else { + in[k] = reflect.ValueOf(arg) + } } return in diff --git a/chainable_test.go b/chainable_test.go index 4b4f993..cb3d2ce 100644 --- a/chainable_test.go +++ b/chainable_test.go @@ -99,6 +99,16 @@ func TestChain(t *testing.T) { returnValue: []Argument{genericStruct{}}, err: nil, }, + { + desc: "With 'nil' value feeded to the chain", + from: []Argument{1, 2, nil}, + funcs: []Function{ + func(a, b int, e error) (int, int, error) { return a, b, e }, + func(a, b int) (int, int) { return a, b }, + }, + returnValue: []Argument{1, 2}, + err: nil, + }, } for _, tc := range testCases { @@ -140,9 +150,8 @@ func TestChainDummy(t *testing.T) { }, { desc: "With cascading error", - from: []Argument{}, + from: []Argument{errGeneric}, funcs: []Function{ - func() error { return errGeneric }, func(e error) error { return e }, func(e error) error { return e }, }, @@ -150,13 +159,13 @@ func TestChainDummy(t *testing.T) { err: nil, }, { - desc: "With cascading error", - from: []Argument{errGeneric}, + desc: "Without argument feedback", + from: []Argument{}, funcs: []Function{ - func(e error) error { return e }, - func(e error) error { return e }, + func() {}, + func() {}, }, - returnValue: []Argument{errGeneric}, + returnValue: []Argument{}, err: nil, }, { @@ -179,6 +188,16 @@ func TestChainDummy(t *testing.T) { returnValue: []Argument{genericStruct{}, errGeneric}, err: nil, }, + { + desc: "With 'nil' value feeded to the chain", + from: []Argument{1, 2, nil}, + funcs: []Function{ + func(a, b int, e error) (int, int, error) { return a, b, e }, + func(a, b int, e error) (int, int, error) { return a, b, e }, + }, + returnValue: []Argument{1, 2, nil}, + err: nil, + }, } for _, tc := range testCases {