diff --git a/encoding.go b/encoding.go index cd1d07e2..ef300a89 100644 --- a/encoding.go +++ b/encoding.go @@ -111,6 +111,39 @@ func MakeTypedEncoder(f interface{}) func(*Request) func(io.Writer) Encoder { }) } +// MakeMultiEncoder takes pairs of arguments, where the first of the pair is a value that denotes a type and the second denotes an encoder. +// The resulting encoder then uses the encoder of the pair with matching first element. +// Example: +// e := MakeMultiEncoder( +// "string", MakeTypedEncoder(func(req *Request, w io.Writer, str string) error { +// // ... +// return nil +// }, +// cmdkit.Error{}, MakeTypedEncoder(func(req *Request, w io.Writer, err cmdkit.Error) error { +// // ... +// return nil +// })) +func MakeMultiEncoder(args ...interface{}) func(*Request) func(io.Writer) Encoder { + if len(args)%2 != 0 { + panic("MakeMultiEncoder must receive an even number of parameters") + } + + types := make(map[reflect.Type]func(*Request) func(io.Writer) Encoder) + + for i := 0; i < len(args); i += 2 { + types[reflect.TypeOf(args[i])] = args[i+1].(func(*Request) func(io.Writer) Encoder) + } + + return MakeEncoder(func(req *Request, w io.Writer, i interface{}) error { + f, ok := types[reflect.TypeOf(i)] + if !ok { + return fmt.Errorf("unexpected type: %T", i) + } + + return f(req)(w).Encode(i) + }) +} + type genericEncoder struct { f func(*Request, io.Writer, interface{}) error w io.Writer diff --git a/encoding_test.go b/encoding_test.go index 51156e59..739a6d33 100644 --- a/encoding_test.go +++ b/encoding_test.go @@ -11,6 +11,10 @@ type fooTestObj struct { Good bool } +type barTestObj struct { + Bad bool +} + func TestMakeTypedEncoder(t *testing.T) { expErr := fmt.Errorf("command fooTestObj failed") f := MakeTypedEncoder(func(req *Request, w io.Writer, v *fooTestObj) error { @@ -55,3 +59,45 @@ func TestMakeTypedEncoderArrays(t *testing.T) { t.Fatal(err) } } + +func TestMakeMultiEncoder(t *testing.T) { + expErrFoo := fmt.Errorf("command fooTestObj failed") + expErrBar := fmt.Errorf("command barTestObj failed") + + f := MakeMultiEncoder( + &fooTestObj{}, MakeTypedEncoder(func(req *Request, w io.Writer, v *fooTestObj) error { + if v.Good { + return nil + } + return expErrFoo + }), + &barTestObj{}, MakeTypedEncoder(func(req *Request, w io.Writer, v *barTestObj) error { + if !v.Bad { + return nil + } + return expErrBar + })) + + req := &Request{} + + encoderFunc := f(req) + + buf := new(bytes.Buffer) + encoder := encoderFunc(buf) + + if err := encoder.Encode(&fooTestObj{true}); err != nil { + t.Fatal(err) + } + + if err := encoder.Encode(&fooTestObj{false}); err != expErrFoo { + t.Fatal("expected: ", expErrFoo) + } + + if err := encoder.Encode(&barTestObj{true}); err != expErrBar { + t.Fatal("expected: ", expErrBar) + } + + if err := encoder.Encode(&barTestObj{false}); err != nil { + t.Fatal(err) + } +}