diff --git a/assert/nop.go b/assert/nop.go new file mode 100644 index 00000000..938cdf9c --- /dev/null +++ b/assert/nop.go @@ -0,0 +1,8 @@ +package assert + +// Nop returns an assertion that does not assert anything. +func Nop() Assertion { + return AssertionFunc(func(v interface{}) error { + return nil + }) +} diff --git a/internal/queryutil/query.go b/internal/queryutil/query.go index ac75ca17..93a7019f 100644 --- a/internal/queryutil/query.go +++ b/internal/queryutil/query.go @@ -1,10 +1,14 @@ package queryutil import ( + "context" + "reflect" + "strings" "sync" query "github.com/zoncoen/query-go" yamlextractor "github.com/zoncoen/query-go/extractor/yaml" + "google.golang.org/protobuf/types/dynamicpb" ) var ( @@ -23,11 +27,72 @@ func Options() []query.Option { []query.Option{ query.ExtractByStructTag("yaml", "json"), query.CustomExtractFunc(yamlextractor.MapSliceExtractFunc()), + query.CustomExtractFunc(dynamicpbExtractFunc()), }, opts..., ) } +func dynamicpbExtractFunc() func(query.ExtractFunc) query.ExtractFunc { + return func(f query.ExtractFunc) query.ExtractFunc { + return func(in reflect.Value) (reflect.Value, bool) { + v := in + if v.IsValid() && v.CanInterface() { + if msg, ok := v.Interface().(*dynamicpb.Message); ok { + return f(reflect.ValueOf(&keyExtractor{ + v: msg, + })) + } + } + return f(in) + } + } +} + +type keyExtractor struct { + v *dynamicpb.Message +} + +// ExtractByKey implements the query.KeyExtractorContext interface. +func (e *keyExtractor) ExtractByKey(ctx context.Context, key string) (interface{}, bool) { + ci := query.IsCaseInsensitive(ctx) + if ci { + key = strings.ToLower(key) + } + fields := e.v.Descriptor().Fields() + for i := range fields.Len() { + f := fields.Get(i) + { + name := string(f.Name()) + if ci { + name = strings.ToLower(name) + } + if name == key { + return e.v.Get(f).Interface(), true + } + } + { + name := f.TextName() + if ci { + name = strings.ToLower(name) + } + if name == key { + return e.v.Get(f).Interface(), true + } + } + if f.HasJSONName() { + name := f.JSONName() + if ci { + name = strings.ToLower(name) + } + if name == key { + return e.v.Get(f).Interface(), true + } + } + } + return nil, false +} + func AppendOptions(customOpts ...query.Option) { m.Lock() defer m.Unlock() diff --git a/internal/queryutil/query_test.go b/internal/queryutil/query_test.go new file mode 100644 index 00000000..1ac4c4ec --- /dev/null +++ b/internal/queryutil/query_test.go @@ -0,0 +1,92 @@ +package queryutil + +import ( + "context" + "strings" + "testing" + + "github.com/zoncoen/query-go" + "github.com/zoncoen/scenarigo/protocol/grpc/proto" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/dynamicpb" +) + +func TestKeyExtractor_ExtractByKey(t *testing.T) { + comp := proto.NewCompiler([]string{}) + fds, err := comp.Compile(context.Background(), []string{"testdata/test.proto"}) + if err != nil { + t.Fatal(err) + } + svc, err := fds.ResolveService("scenarigo.testdata.test.Test") + if err != nil { + t.Fatal(err) + } + method := svc.Methods().ByName("Echo") + if method == nil { + t.Fatal("failed to get method") + } + msg := dynamicpb.NewMessage(method.Input()) + msg.Set(method.Input().Fields().ByName("message_id"), protoreflect.ValueOf("1")) + + t.Run("success", func(t *testing.T) { + tests := map[string]struct { + key string + opts []query.Option + expect any + }{ + "name": { + key: "message_id", + expect: "1", + }, + "name (case insensitive)": { + key: "MESSAGE_ID", + opts: []query.Option{query.CaseInsensitive()}, + expect: "1", + }, + "json": { + key: "messageId", + expect: "1", + }, + "json (case insensitive)": { + key: "messageid", + opts: []query.Option{query.CaseInsensitive()}, + expect: "1", + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + v, err := New(test.opts...).Key(test.key).Extract(msg) + if err != nil { + t.Fatal(err) + } + if got, expect := v, test.expect; got != expect { + t.Errorf("expect %v but got %v", expect, got) + } + }) + } + }) + + t.Run("failure", func(t *testing.T) { + tests := map[string]struct { + key string + opts []query.Option + expect string + }{ + "not found": { + key: "messageid", + expect: `".messageid" not found`, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + _, err := New(test.opts...).Key(test.key).Extract(msg) + if err == nil { + t.Fatal("no error") + } + if got, expect := err.Error(), test.expect; !strings.Contains(got, test.expect) { + t.Errorf("expect %q but got %q", expect, got) + } + }) + } + }) +} diff --git a/internal/queryutil/testdata/test.proto b/internal/queryutil/testdata/test.proto new file mode 100644 index 00000000..80985ede --- /dev/null +++ b/internal/queryutil/testdata/test.proto @@ -0,0 +1,17 @@ +syntax = "proto3"; + +package scenarigo.testdata.test; + +service Test { + rpc Echo(EchoRequest) returns (EchoResponse) {}; +} + +message EchoRequest { + string message_id = 1; + string message_body = 2; +} + +message EchoResponse { + string message_id = 1; + string message_body = 2; +} diff --git a/internal/yamlutil/grpc.go b/internal/yamlutil/grpc.go new file mode 100644 index 00000000..fbce21cc --- /dev/null +++ b/internal/yamlutil/grpc.go @@ -0,0 +1,34 @@ +package yamlutil + +import ( + "encoding/hex" + "strings" + "unicode/utf8" + + "google.golang.org/grpc/metadata" + + "github.com/goccy/go-yaml" +) + +func NewMDMarshaler(md metadata.MD) *MDMarshaler { return (*MDMarshaler)(&md) } + +type MDMarshaler metadata.MD + +func (m *MDMarshaler) MarshalYAML() ([]byte, error) { + mp := make(metadata.MD, len(*m)) + for k, vs := range *m { + if !strings.HasSuffix(k, "-bin") { + mp[k] = vs + continue + } + s := make([]string, len(vs)) + for i, v := range vs { + if !utf8.ValidString(v) { + v = hex.EncodeToString([]byte(v)) + } + s[i] = v + } + mp[k] = s + } + return yaml.Marshal(mp) +} diff --git a/mock/protocol/grpc/grpc.go b/mock/protocol/grpc/grpc.go new file mode 100644 index 00000000..9048ca84 --- /dev/null +++ b/mock/protocol/grpc/grpc.go @@ -0,0 +1,204 @@ +package grpc + +import ( + "context" + "errors" + "fmt" + "net" + "sync" + "time" + + "github.com/goccy/go-yaml" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + healthpb "google.golang.org/grpc/health/grpc_health_v1" + + "github.com/zoncoen/scenarigo/logger" + "github.com/zoncoen/scenarigo/mock/protocol" + "github.com/zoncoen/scenarigo/protocol/grpc/proto" +) + +var waitInterval = 100 * time.Millisecond + +// Register registers grpc protocol. +func Register() { + protocol.Register(&GRPC{}) +} + +// GRPC is a protocol type for the mock. +type GRPC struct{} + +// Name implements protocol.Protocol interface. +func (_ GRPC) Name() string { return "grpc" } //nolint:revive + +// UnmarshalConfig implements protocol.Protocol interface. +func (_ GRPC) UnmarshalConfig(b []byte) (interface{}, error) { //nolint:revive + var config ServerConfig + if err := yaml.Unmarshal(b, &config); err != nil { + return nil, err + } + return &config, nil +} + +// NewServer implements protocol.Protocol interface. +func (_ *GRPC) NewServer(iter *protocol.MockIterator, l logger.Logger, config interface{}) (protocol.Server, error) { //nolint:revive + if iter == nil { + return nil, errors.New("mock iterator is nil") + } + cfg, ok := config.(*ServerConfig) + if !ok { + return nil, fmt.Errorf("invalid config %T", config) + } + srv := &server{ + iter: iter, + } + if cfg != nil { + srv.config = *cfg + comp := proto.NewCompiler(cfg.Proto.Imports) + fds, err := comp.Compile(context.Background(), cfg.Proto.Files) + if err != nil { + return nil, fmt.Errorf("failed to compile proto: %w", err) + } + srv.resolver = fds + } + return srv, nil +} + +// ServerConfig represents a server configuration. +type ServerConfig struct { + Port int `yaml:"port,omitempty"` + Proto ProtoConfig `yaml:"proto,omitempty"` +} + +// ProtoConfig represents a proto configuration. +type ProtoConfig struct { + Imports []string `yaml:"imports,omitempty"` + Files []string `yaml:"files,omitempty"` +} + +type server struct { + m sync.Mutex + config ServerConfig + iter *protocol.MockIterator + resolver proto.ServiceDescriptorResolver + addr string + srv *grpc.Server +} + +// Start implements protocol.Server interface. +func (s *server) Start(ctx context.Context) error { + s.m.Lock() + serve, err := s.setup() + if err != nil { + s.m.Unlock() + return err + } + s.m.Unlock() + return serve() +} + +func (s *server) setup() (func() error, error) { + if s.srv != nil { + return nil, errors.New("server already started") + } + ln, err := net.Listen("tcp", fmt.Sprintf(":%d", s.config.Port)) + if err != nil { + return nil, fmt.Errorf("failed to listen: %w", err) + } + s.addr = ln.Addr().String() + s.srv = grpc.NewServer() + healthpb.RegisterHealthServer(s.srv, &healthServer{}) + names, err := s.resolver.ListServices() + if err != nil { + return nil, fmt.Errorf("failed to get service descriptor: %w", err) + } + for _, name := range names { + sd, err := s.resolver.ResolveService(name) + if err != nil { + return nil, fmt.Errorf("failed to get service descriptor: %w", err) + } + s.srv.RegisterService(s.convertToServicDesc(sd), nil) + } + return func() error { + if err := s.srv.Serve(ln); err != nil { + if !errors.Is(err, grpc.ErrServerStopped) { + return err + } + } + return nil + }, nil +} + +// Wait implements protocol.Server interface. +func (s *server) Wait(ctx context.Context) error { + ch := make(chan error) + go func() { + ch <- s.wait(ctx) + }() + select { + case <-ctx.Done(): + return context.Canceled + case err := <-ch: + return err + } +} + +func (s *server) wait(ctx context.Context) error { + var once sync.Once + var client healthpb.HealthClient + for { + if err := ctx.Err(); err != nil { + return err + } + s.m.Lock() + srv := s.srv + s.m.Unlock() + if srv != nil { + var err error + once.Do(func() { + c, cErr := grpc.NewClient(s.addr, grpc.WithTransportCredentials(insecure.NewCredentials())) + if cErr != nil { + err = fmt.Errorf("failed to connect server: %w", err) + return + } + client = healthpb.NewHealthClient(c) + }) + if err != nil { + return err + } + resp, err := client.Check(ctx, &healthpb.HealthCheckRequest{ + Service: "grpc.health.v1", + }) + if err == nil { + if resp.GetStatus() == healthpb.HealthCheckResponse_SERVING { + return nil + } + } + } + time.Sleep(waitInterval) + } +} + +// Stop implements protocol.Server interface. +func (s *server) Stop(ctx context.Context) error { + s.m.Lock() + defer s.m.Unlock() + if s.srv == nil { + return protocol.ErrServerClosed + } + s.addr = "" + srv := s.srv + s.srv = nil + srv.GracefulStop() // GracefulStop() calls s.ln.Close() + return nil +} + +// Addr implements protocol.Server interface. +func (s *server) Addr() (string, error) { + s.m.Lock() + defer s.m.Unlock() + if s.srv == nil { + return "", protocol.ErrServerClosed + } + return s.addr, nil +} diff --git a/mock/protocol/grpc/grpc_test.go b/mock/protocol/grpc/grpc_test.go new file mode 100644 index 00000000..d17187cd --- /dev/null +++ b/mock/protocol/grpc/grpc_test.go @@ -0,0 +1,201 @@ +package grpc + +import ( + "context" + "os" + "strings" + "testing" + "time" + + "github.com/goccy/go-yaml" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" + healthpb "google.golang.org/grpc/health/grpc_health_v1" + "google.golang.org/grpc/status" + + "github.com/zoncoen/scenarigo/logger" + "github.com/zoncoen/scenarigo/mock/protocol" + testpb "github.com/zoncoen/scenarigo/testdata/gen/pb/test" +) + +func init() { + Register() +} + +func TestGRPC_Server(t *testing.T) { + cfg := ` +proto: + files: + - ./testdata/test.proto +` + + tests := map[string]struct { + filename string + config string + f func(*testing.T, string) + }{ + "empty": { + filename: "testdata/empty.yaml", + f: func(t *testing.T, addr string) { + t.Helper() + c, err := grpc.NewClient(addr, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Fatalf("failed to connect server: %s", err) + } + client := healthpb.NewHealthClient(c) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + resp, err := client.Check(ctx, &healthpb.HealthCheckRequest{ + Service: healthpb.Health_ServiceDesc.ServiceName, + }) + if err != nil { + t.Fatal(err) + } + if got, expect := resp.GetStatus(), healthpb.HealthCheckResponse_SERVING; got != expect { + t.Errorf("expect %d but got %d", expect, got) + } + }, + }, + "success": { + filename: "testdata/grpc.yaml", + config: cfg, + f: sendEchoRequest(nil, "1", "hello"), + }, + "int status code": { + filename: "testdata/int-status-code.yaml", + config: cfg, + f: sendEchoRequest(nil, "1", "hello"), + }, + "no expect and response": { + filename: "testdata/no-expect-response.yaml", + config: cfg, + f: sendEchoRequest(nil, "", ""), + }, + "unauthenticated": { + filename: "testdata/unauthenticated.yaml", + config: cfg, + f: sendEchoRequest(status.New(codes.Unauthenticated, "Unauthenticated"), "", ""), + }, + "invalid expect service": { + filename: "testdata/invalid-expect-service.yaml", + config: cfg, + f: sendEchoRequest(status.New(codes.InvalidArgument, ".expect.service: request assertion failed"), "", ""), + }, + "invalid expect method": { + filename: "testdata/invalid-expect-method.yaml", + config: cfg, + f: sendEchoRequest(status.New(codes.InvalidArgument, ".expect.method: request assertion failed"), "", ""), + }, + "invalid expect metadata": { + filename: "testdata/invalid-expect-metadata.yaml", + config: cfg, + f: sendEchoRequest(status.New(codes.InvalidArgument, ".expect.metadata.content-type: request assertion failed"), "", ""), + }, + "invalid expect message": { + filename: "testdata/invalid-expect-metadata.yaml", + config: cfg, + f: sendEchoRequest(status.New(codes.InvalidArgument, ".expect.metadata.content-type: request assertion failed"), "", ""), + }, + } + for name, test := range tests { + test := test + t.Run(name, func(t *testing.T) { + p := protocol.Get("grpc") + if p == nil { + t.Fatal("failed to get protocol") + } + f, err := os.Open(test.filename) + if err != nil { + t.Fatal(err) + } + defer f.Close() + var mocks []protocol.Mock + if err := yaml.NewDecoder(f).Decode(&mocks); err != nil { + t.Fatal(err) + } + iter := protocol.NewMockIterator(mocks) + defer func() { + if err := iter.Stop(); err != nil { + t.Errorf("failed to stop mock iterator: %s", err) + } + }() + + // unmarshal config + cfg, err := p.UnmarshalConfig([]byte(test.config)) + if err != nil { + t.Fatalf("failed to unmarshal config: %s", err) + } + + // start server + srv, err := p.NewServer(iter, logger.NewNopLogger(), cfg) + if err != nil { + t.Fatalf("failed to create server: %s", err) + } + go func() { + if err := srv.Start(context.Background()); err != nil { + t.Errorf("failed to start server: %s", err) + } + }() + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + if err := srv.Wait(ctx); err != nil { + t.Fatalf("failed to start server: %s", err) + } + + addr, err := srv.Addr() + if err != nil { + t.Errorf("failed to get address: %s", err) + } + test.f(t, addr) + + // stop server + ctx, cancel = context.WithTimeout(context.Background(), time.Second) + defer cancel() + if err := srv.Stop(ctx); err != nil { + t.Fatalf("failed to stop server: %s", err) + } + }) + } +} + +func sendEchoRequest(st *status.Status, id, msg string) func(t *testing.T, addr string) { + return func(t *testing.T, addr string) { + t.Helper() + c, err := grpc.NewClient(addr, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Fatalf("failed to connect server: %s", err) + } + client := testpb.NewTestClient(c) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + resp, err := client.Echo(ctx, &testpb.EchoRequest{ + MessageId: "1", + MessageBody: "hello", + }) + if err != nil { + serr := status.Convert(err) + if st == nil { + t.Fatalf("expect status code %s but got %s", codes.OK, serr.Code()) + } + if got, expect := serr.Code(), st.Code(); got != expect { + t.Errorf("expect status code %s but got %s", expect, got) + } + if got, expect := serr.Message(), st.Message(); !strings.Contains(got, expect) { + t.Errorf("expect status message %s but got %s", expect, got) + } + return + } + if st != nil { + if got, expect := codes.OK, st.Code(); got != expect { + t.Errorf("expect %s but got %s", expect, got) + } + } + if got, expect := resp.GetMessageId(), id; got != expect { + t.Errorf("expect %s but got %s", expect, got) + } + if got, expect := resp.GetMessageBody(), msg; got != expect { + t.Errorf("expect %s but got %s", expect, got) + } + } +} diff --git a/mock/protocol/grpc/handler.go b/mock/protocol/grpc/handler.go new file mode 100644 index 00000000..f917c7ef --- /dev/null +++ b/mock/protocol/grpc/handler.go @@ -0,0 +1,257 @@ +package grpc + +import ( + gocontext "context" + "fmt" + "math" + "strconv" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/dynamicpb" + + "github.com/goccy/go-yaml" + + "github.com/zoncoen/scenarigo/assert" + "github.com/zoncoen/scenarigo/context" + "github.com/zoncoen/scenarigo/errors" + "github.com/zoncoen/scenarigo/internal/assertutil" + "github.com/zoncoen/scenarigo/internal/yamlutil" + grpcprotocol "github.com/zoncoen/scenarigo/protocol/grpc" +) + +func (s *server) convertToServicDesc(sd protoreflect.ServiceDescriptor) *grpc.ServiceDesc { + desc := &grpc.ServiceDesc{ + ServiceName: string(sd.FullName()), + Metadata: sd.ParentFile().Path(), + } + for i := 0; i < sd.Methods().Len(); i++ { + m := sd.Methods().Get(i) + // TODO: streaming RPC + // if m.IsStreamingServer() || m.IsStreamingClient() { + // desc.Streams = append(desc.Streams, grpc.StreamDesc{ + // StreamName: string(m.Name()), + // ServerStreams: m.IsStreamingServer(), + // ClientStreams: m.IsStreamingClient(), + // Handler: func(srv any, stream grpc.ServerStream) error { + // return nil + // }, + // }) + // } else { + desc.Methods = append(desc.Methods, grpc.MethodDesc{ + MethodName: string(m.Name()), + Handler: s.unaryHandler(sd.FullName(), m), + }) + // } + } + return desc +} + +func (s *server) unaryHandler(svcName protoreflect.FullName, method protoreflect.MethodDescriptor) func(srv any, ctx gocontext.Context, dec func(any) error, interceptor grpc.UnaryServerInterceptor) (any, error) { + return func(srv any, ctx gocontext.Context, dec func(any) error, interceptor grpc.UnaryServerInterceptor) (any, error) { + mock, err := s.iter.Next() + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to get mock: %s", err) + } + + if mock.Protocol != "grpc" { + return nil, status.Error(codes.Internal, errors.WithPath(fmt.Errorf("received gRPC request but the mock protocol is %q", mock.Protocol), "protocol").Error()) + } + + var e expect + if err := mock.Expect.Unmarshal(&e); err != nil { + return nil, status.Error(codes.Internal, errors.WrapPath(err, "expect", "failed to unmarshal").Error()) + } + assertion, err := e.build(context.New(nil)) + if err != nil { + return nil, status.Error(codes.Internal, errors.WrapPath(err, "expect", "failed to build assretion").Error()) + } + + var md metadata.MD + if got, ok := metadata.FromIncomingContext(ctx); ok { + md = got + } + req := dynamicpb.NewMessage(method.Input()) + if err := dec(req); err != nil { + return nil, status.Error(codes.Internal, errors.WrapPath(err, "expect.message", "failed to decode message").Error()) + } + if err := assertion.Assert(&request{ + service: string(svcName), + method: string(method.Name()), + metadata: yamlutil.NewMDMarshaler(md), + message: req, + }); err != nil { + return nil, status.Error(codes.InvalidArgument, errors.WrapPath(err, "expect", "request assertion failed").Error()) + } + + var resp Response + if err := mock.Response.Unmarshal(&resp); err != nil { + return nil, status.Error(codes.Internal, errors.WrapPath(err, "response", "failed to unmarshal response").Error()) + } + sctx := context.New(nil) + v, err := sctx.ExecuteTemplate(resp) + if err != nil { + return nil, status.Error(codes.Internal, errors.WrapPath(err, "response", "failed to execute template of response").Error()) + } + resp, ok := v.(Response) + if !ok { + return nil, status.Error(codes.Internal, errors.WithPath(fmt.Errorf("failed to execute template of response: unexpected type %T", v), "response").Error()) + } + + var msg proto.Message = dynamicpb.NewMessage(method.Output()) + msg, serr, err := resp.extract(msg) + if err != nil { + return nil, status.Error(codes.Internal, errors.WithPath(err, "response").Error()) + } + return msg, serr.Err() + } +} + +type request struct { + service string + method string + metadata *yamlutil.MDMarshaler + message any +} + +type expect struct { + Service *string `yaml:"service"` + Method *string `yaml:"method"` + Metadata yaml.MapSlice `yaml:"metadata"` + Message any `yaml:"message"` +} + +func (e *expect) build(ctx *context.Context) (assert.Assertion, error) { + var ( + serviceAssertion = assert.Nop() + methodAssertion = assert.Nop() + err error + ) + if e.Service != nil { + serviceAssertion, err = assert.Build(ctx.RequestContext(), *e.Service, assert.FromTemplate(ctx)) + if err != nil { + return nil, errors.WrapPathf(err, "service", "invalid expect service") + } + } + if e.Method != nil { + methodAssertion, err = assert.Build(ctx.RequestContext(), *e.Method, assert.FromTemplate(ctx)) + if err != nil { + return nil, errors.WrapPathf(err, "method", "invalid expect method") + } + } + + metadataAssertion, err := assertutil.BuildHeaderAssertion(ctx, e.Metadata) + if err != nil { + return nil, errors.WrapPathf(err, "metadata", "invalid expect metadata") + } + + assertion, err := assert.Build(ctx.RequestContext(), e.Message, assert.FromTemplate(ctx)) + if err != nil { + return nil, errors.WrapPathf(err, "message", "invalid expect response message") + } + + return assert.AssertionFunc(func(v interface{}) error { + req, ok := v.(*request) + if !ok { + return errors.Errorf("expected request but got %T", v) + } + if err := serviceAssertion.Assert(req.service); err != nil { + return errors.WithPath(err, "service") + } + if err := methodAssertion.Assert(req.method); err != nil { + return errors.WithPath(err, "method") + } + if err := metadataAssertion.Assert(req.metadata); err != nil { + return errors.WithPath(err, "metadata") + } + if err := assertion.Assert(req.message); err != nil { + return errors.WithPath(err, "message") + } + return nil + }), nil +} + +// Response represents an gRPC response. +type Response grpcprotocol.Expect + +func (resp *Response) extract(msg proto.Message) (proto.Message, *status.Status, error) { + if resp.Status.Code != "" { + code := codes.OK + c, err := strToCode(resp.Status.Code) + if err != nil { + return nil, nil, errors.WithPath(err, "status.code") + } + code = c + + smsg := code.String() + if resp.Status.Message != "" { + smsg = resp.Status.Message + } + + if code != codes.OK { + return nil, status.New(code, smsg), nil + } + } + + if resp.Message != nil { + if err := grpcprotocol.ConvertToProto(resp.Message, msg); err != nil { + return nil, nil, errors.WrapPath(err, "message", "invalid message") + } + } + + return msg, nil, nil +} + +func strToCode(s string) (codes.Code, error) { + switch s { + case "OK": + return codes.OK, nil + case "Canceled": + return codes.Canceled, nil + case "Unknown": + return codes.Unknown, nil + case "InvalidArgument": + return codes.InvalidArgument, nil + case "DeadlineExceeded": + return codes.DeadlineExceeded, nil + case "NotFound": + return codes.NotFound, nil + case "AlreadyExists": + return codes.AlreadyExists, nil + case "PermissionDenied": + return codes.PermissionDenied, nil + case "ResourceExhausted": + return codes.ResourceExhausted, nil + case "FailedPrecondition": + return codes.FailedPrecondition, nil + case "Aborted": + return codes.Aborted, nil + case "OutOfRange": + return codes.OutOfRange, nil + case "Unimplemented": + return codes.Unimplemented, nil + case "Internal": + return codes.Internal, nil + case "Unavailable": + return codes.Unavailable, nil + case "DataLoss": + return codes.DataLoss, nil + case "Unauthenticated": + return codes.Unauthenticated, nil + } + if i, err := strconv.Atoi(s); err == nil { + return intToCode(i) + } + return codes.Unknown, fmt.Errorf("invalid status code %q", s) +} + +func intToCode(i int) (codes.Code, error) { + if i > math.MaxUint32 { + return 0, errors.Errorf("invalid status code %d: exceeds the maximum limit for uint32", i) + } + return codes.Code(i), nil +} diff --git a/mock/protocol/grpc/handler_test.go b/mock/protocol/grpc/handler_test.go new file mode 100644 index 00000000..e8a6b325 --- /dev/null +++ b/mock/protocol/grpc/handler_test.go @@ -0,0 +1,185 @@ +package grpc + +import ( + "context" + "errors" + "strings" + "testing" + + "google.golang.org/protobuf/reflect/protoreflect" + + "github.com/zoncoen/scenarigo/internal/yamlutil" + "github.com/zoncoen/scenarigo/mock/protocol" + "github.com/zoncoen/scenarigo/protocol/grpc/proto" +) + +func TestUnaryHandler_failure(t *testing.T) { + comp := proto.NewCompiler(nil) + fds, err := comp.Compile(context.Background(), []string{"./testdata/test.proto"}) + if err != nil { + t.Fatalf("failed to compile proto: %s", err) + } + svcName := protoreflect.FullName("scenarigo.testdata.test.Test") + sd, err := fds.ResolveService(svcName) + if err != nil { + t.Fatalf("failed to resovle service: %s", err) + } + md := sd.Methods().ByName("Echo") + + tests := map[string]struct { + mocks []protocol.Mock + svcName protoreflect.FullName + method protoreflect.MethodDescriptor + decode func(any) error + expect string + }{ + "no mock": { + expect: "failed to get mock: no mocks remain", + }, + "protocol must be grpc": { + mocks: []protocol.Mock{ + { + Protocol: "http", + }, + }, + expect: `received gRPC request but the mock protocol is "http"`, + }, + "failed to unmarshal expect": { + mocks: []protocol.Mock{ + { + Protocol: "grpc", + Expect: yamlutil.RawMessage("-"), + }, + }, + expect: "failed to unmarshal: [1:1] string was used where mapping is expected", + }, + "invalid expect service": { + mocks: []protocol.Mock{ + { + Protocol: "grpc", + Expect: yamlutil.RawMessage(`service: '{{'`), + }, + }, + expect: ".expect.service: failed to build assretion: invalid expect service: failed to build assertion", + }, + "invalid expect method": { + mocks: []protocol.Mock{ + { + Protocol: "grpc", + Expect: yamlutil.RawMessage(`method: '{{'`), + }, + }, + expect: ".expect.method: failed to build assretion: invalid expect method: failed to build assertion", + }, + "invalid expect metadata": { + mocks: []protocol.Mock{ + { + Protocol: "grpc", + Expect: yamlutil.RawMessage("metadata:\n foo: '{{'"), + }, + }, + expect: ".expect.metadata.foo: failed to build assretion: invalid expect metadata: failed to build assertion", + }, + "invalid expect message": { + mocks: []protocol.Mock{ + { + Protocol: "grpc", + Expect: yamlutil.RawMessage(`message: '{{'`), + }, + }, + expect: ".expect.message: failed to build assretion: invalid expect response message: failed to build assertion", + }, + "failed to decode message": { + mocks: []protocol.Mock{ + { + Protocol: "grpc", + Expect: yamlutil.RawMessage(""), + }, + }, + svcName: svcName, + method: md, + decode: func(_ any) error { return errors.New("ERROR") }, + expect: ".expect.message: failed to decode message: ERROR", + }, + "assertion error": { + mocks: []protocol.Mock{ + { + Protocol: "grpc", + Expect: yamlutil.RawMessage("message:\n messageId: '1'"), + }, + }, + svcName: svcName, + method: md, + decode: func(_ any) error { return nil }, + expect: `request assertion failed: expected "1" but got ""`, + }, + "failed to unmarshal response": { + mocks: []protocol.Mock{ + { + Protocol: "grpc", + Expect: yamlutil.RawMessage(""), + Response: yamlutil.RawMessage("-"), + }, + }, + svcName: svcName, + method: md, + decode: func(_ any) error { return nil }, + expect: ".response: failed to unmarshal response: [1:1] string was used where mapping is expected", + }, + "failed to execute template of response": { + mocks: []protocol.Mock{ + { + Protocol: "grpc", + Expect: yamlutil.RawMessage(""), + Response: yamlutil.RawMessage("message: '{{'"), + }, + }, + svcName: svcName, + method: md, + decode: func(_ any) error { return nil }, + expect: ".response.message: failed to execute template of response", + }, + "invalid reponse status code": { + mocks: []protocol.Mock{ + { + Protocol: "grpc", + Expect: yamlutil.RawMessage(""), + Response: yamlutil.RawMessage("status:\n code: aaa"), + }, + }, + svcName: svcName, + method: md, + decode: func(_ any) error { return nil }, + expect: ".response.status.code: invalid status code", + }, + "invalid reponse message": { + mocks: []protocol.Mock{ + { + Protocol: "grpc", + Expect: yamlutil.RawMessage(""), + Response: yamlutil.RawMessage("message:\n id: '1'"), + }, + }, + svcName: svcName, + method: md, + decode: func(_ any) error { return nil }, + expect: ".response.message: invalid message", + }, + } + + for name, test := range tests { + test := test + t.Run(name, func(t *testing.T) { + iter := protocol.NewMockIterator(test.mocks) + srv := &server{ + iter: iter, + } + ctx := context.Background() + if _, err := srv.unaryHandler(test.svcName, test.method)(nil, ctx, test.decode, nil); err == nil { + t.Fatal("no error") + } else if !strings.Contains(err.Error(), test.expect) { + t.Errorf("expect error %q but got %q", test.expect, err) + } + }) + } +} diff --git a/mock/protocol/grpc/health.go b/mock/protocol/grpc/health.go new file mode 100644 index 00000000..85e9d59c --- /dev/null +++ b/mock/protocol/grpc/health.go @@ -0,0 +1,24 @@ +package grpc + +import ( + "context" + + "google.golang.org/grpc" + healthpb "google.golang.org/grpc/health/grpc_health_v1" +) + +type healthServer struct{} + +// Check implements healthpb.HealthServer interface. +func (s *healthServer) Check(ctx context.Context, req *healthpb.HealthCheckRequest) (*healthpb.HealthCheckResponse, error) { + return &healthpb.HealthCheckResponse{ + Status: healthpb.HealthCheckResponse_SERVING, + }, nil +} + +// Watch implements healthpb.HealthServer interface. +func (s *healthServer) Watch(req *healthpb.HealthCheckRequest, streams grpc.ServerStreamingServer[healthpb.HealthCheckResponse]) error { + return streams.SendMsg(&healthpb.HealthCheckResponse{ + Status: healthpb.HealthCheckResponse_SERVING, + }) +} diff --git a/mock/protocol/grpc/testdata/empty.yaml b/mock/protocol/grpc/testdata/empty.yaml new file mode 100644 index 00000000..fe51488c --- /dev/null +++ b/mock/protocol/grpc/testdata/empty.yaml @@ -0,0 +1 @@ +[] diff --git a/mock/protocol/grpc/testdata/grpc.yaml b/mock/protocol/grpc/testdata/grpc.yaml new file mode 100644 index 00000000..8317da64 --- /dev/null +++ b/mock/protocol/grpc/testdata/grpc.yaml @@ -0,0 +1,15 @@ +- protocol: grpc + expect: + service: scenarigo.testdata.test.Test + method: Echo + metadata: + content-type: application/grpc + message: + messageId: '1' + messageBody: 'hello' + response: + status: + code: OK + message: + messageId: '1' + messageBody: 'hello' diff --git a/mock/protocol/grpc/testdata/int-status-code.yaml b/mock/protocol/grpc/testdata/int-status-code.yaml new file mode 100644 index 00000000..32342072 --- /dev/null +++ b/mock/protocol/grpc/testdata/int-status-code.yaml @@ -0,0 +1,10 @@ +- protocol: grpc + expect: + service: scenarigo.testdata.test.Test + method: Echo + response: + status: + code: 0 + message: + messageId: '1' + messageBody: 'hello' diff --git a/mock/protocol/grpc/testdata/invalid-expect-message.yaml b/mock/protocol/grpc/testdata/invalid-expect-message.yaml new file mode 100644 index 00000000..6aeef159 --- /dev/null +++ b/mock/protocol/grpc/testdata/invalid-expect-message.yaml @@ -0,0 +1,9 @@ +- protocol: grpc + expect: + service: scenarigo.testdata.test.Test + method: Echo + metadata: + content-type: application/grpc + message: + messageId: '2' + messageBody: 'hello' diff --git a/mock/protocol/grpc/testdata/invalid-expect-metadata.yaml b/mock/protocol/grpc/testdata/invalid-expect-metadata.yaml new file mode 100644 index 00000000..9bc21426 --- /dev/null +++ b/mock/protocol/grpc/testdata/invalid-expect-metadata.yaml @@ -0,0 +1,9 @@ +- protocol: grpc + expect: + service: scenarigo.testdata.test.Test + method: Echo + metadata: + content-type: application/json + message: + messageId: '1' + messageBody: 'hello' diff --git a/mock/protocol/grpc/testdata/invalid-expect-method.yaml b/mock/protocol/grpc/testdata/invalid-expect-method.yaml new file mode 100644 index 00000000..9eea3312 --- /dev/null +++ b/mock/protocol/grpc/testdata/invalid-expect-method.yaml @@ -0,0 +1,9 @@ +- protocol: grpc + expect: + service: scenarigo.testdata.test.Test + method: Foo + metadata: + content-type: application/grpc + message: + messageId: '1' + messageBody: 'hello' diff --git a/mock/protocol/grpc/testdata/invalid-expect-service.yaml b/mock/protocol/grpc/testdata/invalid-expect-service.yaml new file mode 100644 index 00000000..d3b8b550 --- /dev/null +++ b/mock/protocol/grpc/testdata/invalid-expect-service.yaml @@ -0,0 +1,9 @@ +- protocol: grpc + expect: + service: scenarigo.testdata.test.Foo + method: Echo + metadata: + content-type: application/grpc + message: + messageId: '1' + messageBody: 'hello' diff --git a/mock/protocol/grpc/testdata/no-expect-response.yaml b/mock/protocol/grpc/testdata/no-expect-response.yaml new file mode 100644 index 00000000..5fdb0808 --- /dev/null +++ b/mock/protocol/grpc/testdata/no-expect-response.yaml @@ -0,0 +1 @@ +- protocol: grpc diff --git a/mock/protocol/grpc/testdata/test.proto b/mock/protocol/grpc/testdata/test.proto new file mode 100644 index 00000000..80985ede --- /dev/null +++ b/mock/protocol/grpc/testdata/test.proto @@ -0,0 +1,17 @@ +syntax = "proto3"; + +package scenarigo.testdata.test; + +service Test { + rpc Echo(EchoRequest) returns (EchoResponse) {}; +} + +message EchoRequest { + string message_id = 1; + string message_body = 2; +} + +message EchoResponse { + string message_id = 1; + string message_body = 2; +} diff --git a/mock/protocol/grpc/testdata/unauthenticated.yaml b/mock/protocol/grpc/testdata/unauthenticated.yaml new file mode 100644 index 00000000..a5668bb7 --- /dev/null +++ b/mock/protocol/grpc/testdata/unauthenticated.yaml @@ -0,0 +1,7 @@ +- protocol: grpc + expect: + service: scenarigo.testdata.test.Test + method: Echo + response: + status: + code: Unauthenticated diff --git a/protocol/grpc/expect.go b/protocol/grpc/expect.go index 11720a20..764b9c9d 100644 --- a/protocol/grpc/expect.go +++ b/protocol/grpc/expect.go @@ -82,7 +82,7 @@ func (e *Expect) Build(ctx *context.Context) (assert.Assertion, error) { } return assert.AssertionFunc(func(v interface{}) error { - resp, ok := v.(response) + resp, ok := v.(*response) if !ok { return errors.Errorf(`failed to convert to response type. type is %s`, reflect.TypeOf(v)) } @@ -200,7 +200,7 @@ func (e *Expect) assertStatusDetails(assertions []assert.Assertion, sts *status. return nil } -func extract(v response) (proto.Message, *status.Status, error) { +func extract(v *response) (proto.Message, *status.Status, error) { vs := v.rvalues if len(vs) != 2 { return nil, nil, errors.Errorf("expected return value length of method call is 2 but %d", len(vs)) diff --git a/protocol/grpc/expect_test.go b/protocol/grpc/expect_test.go index b53306be..678274ec 100644 --- a/protocol/grpc/expect_test.go +++ b/protocol/grpc/expect_test.go @@ -16,6 +16,7 @@ import ( "github.com/zoncoen/scenarigo/context" "github.com/zoncoen/scenarigo/internal/reflectutil" + "github.com/zoncoen/scenarigo/internal/yamlutil" "github.com/zoncoen/scenarigo/testdata/gen/pb/test" ) @@ -24,11 +25,11 @@ func TestExpect_Build(t *testing.T) { tests := map[string]struct { vars interface{} expect *Expect - v response + v *response }{ "default": { expect: &Expect{}, - v: response{ + v: &response{ rvalues: []reflect.Value{ reflect.ValueOf(&test.EchoResponse{}), reflect.Zero(reflectutil.TypeError), @@ -39,7 +40,7 @@ func TestExpect_Build(t *testing.T) { expect: &Expect{ Code: strconv.Itoa(int(codes.InvalidArgument)), }, - v: response{ + v: &response{ rvalues: []reflect.Value{ reflect.Zero(reflect.TypeOf(&test.EchoResponse{})), reflect.ValueOf(status.New(codes.InvalidArgument, "invalid argument").Err()), @@ -50,7 +51,7 @@ func TestExpect_Build(t *testing.T) { expect: &Expect{ Code: "InvalidArgument", }, - v: response{ + v: &response{ rvalues: []reflect.Value{ reflect.Zero(reflect.TypeOf(&test.EchoResponse{})), reflect.ValueOf(status.New(codes.InvalidArgument, "invalid argument").Err()), @@ -61,7 +62,7 @@ func TestExpect_Build(t *testing.T) { expect: &Expect{ Code: `{{"InvalidArgument"}}`, }, - v: response{ + v: &response{ rvalues: []reflect.Value{ reflect.Zero(reflect.TypeOf(&test.EchoResponse{})), reflect.ValueOf(status.New(codes.InvalidArgument, "invalid argument").Err()), @@ -82,7 +83,7 @@ func TestExpect_Build(t *testing.T) { }, }, }, - v: response{ + v: &response{ rvalues: []reflect.Value{ reflect.ValueOf(&test.EchoResponse{ MessageId: "1", @@ -102,8 +103,8 @@ func TestExpect_Build(t *testing.T) { }, }, }, - v: response{ - Header: newMDMarshaler(metadata.MD{ + v: &response{ + Header: yamlutil.NewMDMarshaler(metadata.MD{ "content-type": []string{ "application/grpc", }, @@ -124,8 +125,8 @@ func TestExpect_Build(t *testing.T) { }, }, }, - v: response{ - Trailer: newMDMarshaler(metadata.MD{ + v: &response{ + Trailer: yamlutil.NewMDMarshaler(metadata.MD{ "content-type": []string{ "application/grpc", }, @@ -161,7 +162,7 @@ func TestExpect_Build(t *testing.T) { }, }, }, - v: response{ + v: &response{ rvalues: []reflect.Value{ reflect.Zero(reflect.TypeOf(&test.EchoResponse{})), reflect.ValueOf(status.FromProto(&spb.Status{ @@ -191,7 +192,7 @@ func TestExpect_Build(t *testing.T) { Message: `{{"invalid argument"}}`, }, }, - v: response{ + v: &response{ rvalues: []reflect.Value{ reflect.Zero(reflect.TypeOf(&test.EchoResponse{})), reflect.ValueOf( @@ -214,7 +215,7 @@ func TestExpect_Build(t *testing.T) { }, }, }, - v: response{ + v: &response{ rvalues: []reflect.Value{ reflect.ValueOf(&test.EchoResponse{ MessageId: "1", @@ -264,7 +265,7 @@ func TestExpect_Build(t *testing.T) { }, }, }, - v: response{ + v: &response{ rvalues: []reflect.Value{ reflect.ValueOf(&test.EchoResponse{ MessageBody: "hello", @@ -282,12 +283,12 @@ func TestExpect_Build(t *testing.T) { }, }).Err()), }, - Header: newMDMarshaler(metadata.MD{ + Header: yamlutil.NewMDMarshaler(metadata.MD{ "content-type": []string{ "application/grpc", }, }), - Trailer: newMDMarshaler(metadata.MD{ + Trailer: yamlutil.NewMDMarshaler(metadata.MD{ "version": []string{ "v1.0.0", }, @@ -315,7 +316,7 @@ func TestExpect_Build(t *testing.T) { t.Run("ng", func(t *testing.T) { tests := map[string]struct { expect *Expect - v response + v *response expectBuildError bool expectAssertError bool expectError string @@ -361,12 +362,12 @@ func TestExpect_Build(t *testing.T) { "return value must be []reflect.Value": { expect: &Expect{}, - v: response{}, + v: &response{}, expectAssertError: true, }, "the length of return values must be 2": { expect: &Expect{}, - v: response{ + v: &response{ rvalues: []reflect.Value{ reflect.ValueOf(&test.EchoResponse{}), }, @@ -375,7 +376,7 @@ func TestExpect_Build(t *testing.T) { }, "fist return value must be proto.Message": { expect: &Expect{}, - v: response{ + v: &response{ rvalues: []reflect.Value{ reflect.ValueOf(status.New(codes.InvalidArgument, "invalid argument").Err()), reflect.ValueOf(status.New(codes.InvalidArgument, "invalid argument").Err()), @@ -385,7 +386,7 @@ func TestExpect_Build(t *testing.T) { }, "second return value must be error": { expect: &Expect{}, - v: response{ + v: &response{ rvalues: []reflect.Value{ reflect.ValueOf(&test.EchoResponse{}), reflect.ValueOf(&test.EchoResponse{}), @@ -395,7 +396,7 @@ func TestExpect_Build(t *testing.T) { }, "wrong code in case of default": { expect: &Expect{}, - v: response{ + v: &response{ rvalues: []reflect.Value{ reflect.Zero(reflect.TypeOf(&test.EchoResponse{})), reflect.ValueOf(status.New(codes.InvalidArgument, "invalid argument").Err()), @@ -407,7 +408,7 @@ func TestExpect_Build(t *testing.T) { expect: &Expect{ Code: "OK", }, - v: response{ + v: &response{ rvalues: []reflect.Value{ reflect.Zero(reflect.TypeOf(&test.EchoResponse{})), reflect.ValueOf(status.New(codes.InvalidArgument, "invalid argument").Err()), @@ -429,7 +430,7 @@ func TestExpect_Build(t *testing.T) { }, }, }, - v: response{ + v: &response{ rvalues: []reflect.Value{ reflect.ValueOf(&test.EchoResponse{ MessageId: "1", @@ -450,8 +451,8 @@ func TestExpect_Build(t *testing.T) { }, }, }, - v: response{ - Header: newMDMarshaler(metadata.MD{ + v: &response{ + Header: yamlutil.NewMDMarshaler(metadata.MD{ "content-type": []string{ "application/grpc", }, @@ -473,8 +474,8 @@ func TestExpect_Build(t *testing.T) { }, }, }, - v: response{ - Header: newMDMarshaler(metadata.MD{ + v: &response{ + Header: yamlutil.NewMDMarshaler(metadata.MD{ "content-type": []string{ "application/grpc", }, @@ -496,8 +497,8 @@ func TestExpect_Build(t *testing.T) { }, }, }, - v: response{ - Header: newMDMarshaler(metadata.MD{ + v: &response{ + Header: yamlutil.NewMDMarshaler(metadata.MD{ "content-type": []string{ "application/grpc", }, @@ -519,8 +520,8 @@ func TestExpect_Build(t *testing.T) { }, }, }, - v: response{ - Trailer: newMDMarshaler(metadata.MD{ + v: &response{ + Trailer: yamlutil.NewMDMarshaler(metadata.MD{ "content-type": []string{ "application/grpc", }, @@ -542,8 +543,8 @@ func TestExpect_Build(t *testing.T) { }, }, }, - v: response{ - Trailer: newMDMarshaler(metadata.MD{ + v: &response{ + Trailer: yamlutil.NewMDMarshaler(metadata.MD{ "content-type": []string{ "application/grpc", }, @@ -561,7 +562,7 @@ func TestExpect_Build(t *testing.T) { Code: "Invalid Argument", }, }, - v: response{ + v: &response{ rvalues: []reflect.Value{ reflect.Zero(reflect.TypeOf(&test.EchoResponse{})), reflect.ValueOf(status.Error(codes.NotFound, "not found")), @@ -576,7 +577,7 @@ func TestExpect_Build(t *testing.T) { Message: "foo", }, }, - v: response{ + v: &response{ rvalues: []reflect.Value{ reflect.Zero(reflect.TypeOf(&test.EchoResponse{})), reflect.ValueOf(status.Error(codes.NotFound, "not found")), @@ -601,7 +602,7 @@ func TestExpect_Build(t *testing.T) { }, }, }, - v: response{ + v: &response{ rvalues: []reflect.Value{ reflect.Zero(reflect.TypeOf(&test.EchoResponse{})), reflect.ValueOf(status.FromProto(&spb.Status{ @@ -641,7 +642,7 @@ func TestExpect_Build(t *testing.T) { }, }, }, - v: response{ + v: &response{ rvalues: []reflect.Value{ reflect.Zero(reflect.TypeOf(&test.EchoResponse{})), reflect.ValueOf(status.FromProto(&spb.Status{ @@ -682,7 +683,7 @@ func TestExpect_Build(t *testing.T) { }, }, }, - v: response{ + v: &response{ rvalues: []reflect.Value{ reflect.Zero(reflect.TypeOf(&test.EchoResponse{})), reflect.ValueOf(status.FromProto(&spb.Status{ @@ -723,7 +724,7 @@ func TestExpect_Build(t *testing.T) { }, }, }, - v: response{ + v: &response{ rvalues: []reflect.Value{ reflect.Zero(reflect.TypeOf(&test.EchoResponse{})), reflect.ValueOf(status.FromProto(&spb.Status{ @@ -764,7 +765,7 @@ func TestExpect_Build(t *testing.T) { }, }, }, - v: response{ + v: &response{ rvalues: []reflect.Value{ reflect.Zero(reflect.TypeOf(&test.EchoResponse{})), reflect.ValueOf(status.FromProto(&spb.Status{ @@ -838,9 +839,9 @@ func TestExpect_Build(t *testing.T) { expect: &Expect{}, v: "string is unexpected value", }, - "invalid type for rvalues of response": { + "invalid type for rvalues of *response": { expect: &Expect{}, - v: response{ + v: &response{ rvalues: []reflect.Value{}, }, }, diff --git a/protocol/grpc/request.go b/protocol/grpc/request.go index 5d631228..a0c8f738 100644 --- a/protocol/grpc/request.go +++ b/protocol/grpc/request.go @@ -1,39 +1,37 @@ package grpc import ( - "bytes" - "encoding/hex" "fmt" "reflect" "strings" - "unicode/utf8" - - "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/metadata" - "google.golang.org/grpc/status" - "google.golang.org/protobuf/encoding/protojson" - "google.golang.org/protobuf/proto" "github.com/goccy/go-yaml" "github.com/zoncoen/scenarigo/context" "github.com/zoncoen/scenarigo/errors" "github.com/zoncoen/scenarigo/internal/queryutil" - "github.com/zoncoen/scenarigo/internal/reflectutil" + "github.com/zoncoen/scenarigo/internal/yamlutil" ) // Request represents a request. type Request struct { - Client string `yaml:"client,omitempty"` - Method string `yaml:"method"` - Metadata interface{} `yaml:"metadata,omitempty"` - Message interface{} `yaml:"message,omitempty"` + Client string `yaml:"client,omitempty"` + Target string `yaml:"target,omitempty"` + Service string `yaml:"service,omitempty"` + Method string `yaml:"method"` + Metadata interface{} `yaml:"metadata,omitempty"` + Message interface{} `yaml:"message,omitempty"` + Options *RequestOptions `yaml:"options,omitempty"` // for backward compatibility Body interface{} `yaml:"body,omitempty"` } +// RequestOptions represents request options. +type RequestOptions struct { + Reflection bool `yaml:"reflection,omitempty"` +} + // RequestExtractor represents a request dump. type RequestExtractor Request @@ -51,11 +49,11 @@ func (r RequestExtractor) ExtractByKey(key string) (interface{}, bool) { } type response struct { - Status responseStatus `yaml:"status,omitempty"` - Header *mdMarshaler `yaml:"header,omitempty"` - Trailer *mdMarshaler `yaml:"trailer,omitempty"` - Message interface{} `yaml:"message,omitempty"` - rvalues []reflect.Value `yaml:"-"` + Status responseStatus `yaml:"status,omitempty"` + Header *yamlutil.MDMarshaler `yaml:"header,omitempty"` + Trailer *yamlutil.MDMarshaler `yaml:"trailer,omitempty"` + Message interface{} `yaml:"message,omitempty"` + rvalues []reflect.Value `yaml:"-"` } type responseStatus struct { @@ -80,31 +78,6 @@ func (r ResponseExtractor) ExtractByKey(key string) (interface{}, bool) { return nil, false } -func newMDMarshaler(md metadata.MD) *mdMarshaler { return (*mdMarshaler)(&md) } - -type mdMarshaler metadata.MD - -func (m *mdMarshaler) MarshalYAML() ([]byte, error) { - mp := make(metadata.MD, len(*m)) - for k, vs := range *m { - vs := vs - if !strings.HasSuffix(k, "-bin") { - mp[k] = vs - continue - } - s := make([]string, len(vs)) - for i, v := range vs { - v := v - if !utf8.ValidString(v) { - v = hex.EncodeToString([]byte(v)) - } - s[i] = v - } - mp[k] = s - } - return yaml.Marshal(mp) -} - const ( indentNum = 2 ) @@ -124,6 +97,11 @@ func (r *Request) addIndent(s string, indentNum int) string { // Invoke implements protocol.Invoker interface. func (r *Request) Invoke(ctx *context.Context) (*context.Context, interface{}, error) { + opts := &RequestOptions{} + if r.Options != nil { + opts = r.Options + } + if r.Client == "" { return ctx, nil, errors.New("gRPC client must be specified") } @@ -132,203 +110,13 @@ func (r *Request) Invoke(ctx *context.Context) (*context.Context, interface{}, e if err != nil { return ctx, nil, errors.WrapPath(err, "client", "failed to get client") } - - client := reflect.ValueOf(x) - var method reflect.Value - for { - if !client.IsValid() { - return nil, nil, errors.ErrorPathf("client", "client %s is invalid", r.Client) - } - method = client.MethodByName(r.Method) - if method.IsValid() { - // method found - break - } - switch client.Kind() { - case reflect.Interface, reflect.Ptr: - client = client.Elem() - default: - return nil, nil, errors.ErrorPathf("method", "method %s.%s not found", r.Client, r.Method) - } - } - - if err := validateMethod(method); err != nil { - return ctx, nil, errors.ErrorPathf("method", `"%s.%s" must be "func(context.Context, proto.Message, ...grpc.CallOption) (proto.Message, error): %s"`, r.Client, r.Method, err) - } - - return invoke(ctx, method, r) -} - -func validateMethod(method reflect.Value) error { - if !method.IsValid() { - return errors.New("invalid") - } - if method.Kind() != reflect.Func { - return errors.New("not function") - } - if method.IsNil() { - return errors.New("method is nil") - } - - mt := method.Type() - if n := mt.NumIn(); n != 3 { - return errors.Errorf("number of arguments must be 3 but got %d", n) - } - if t := mt.In(0); !t.Implements(typeContext) { - return errors.Errorf("first argument must be context.Context but got %s", t.String()) - } - if t := mt.In(1); !t.Implements(typeMessage) { - return errors.Errorf("second argument must be proto.Message but got %s", t.String()) - } - if t := mt.In(2); t != typeCallOpts { - return errors.Errorf("third argument must be []grpc.CallOption but got %s", t.String()) - } - if n := mt.NumOut(); n != 2 { - return errors.Errorf("number of return values must be 2 but got %d", n) - } - if t := mt.Out(0); !t.Implements(typeMessage) { - return errors.Errorf("first return value must be proto.Message but got %s", t.String()) - } - if t := mt.Out(1); !t.Implements(reflectutil.TypeError) { - return errors.Errorf("second return value must be error but got %s", t.String()) + var client serviceClient = &customServiceClient{ + v: reflect.ValueOf(x), } - return nil + return client.invoke(ctx, r, opts) } -func invoke(ctx *context.Context, method reflect.Value, r *Request) (*context.Context, interface{}, error) { - reqCtx := ctx.RequestContext() - if r.Metadata != nil { - x, err := ctx.ExecuteTemplate(r.Metadata) - if err != nil { - return ctx, nil, errors.WrapPathf(err, "metadata", "failed to set metadata") - } - md, err := reflectutil.ConvertStringsMap(reflect.ValueOf(x)) - if err != nil { - return nil, nil, errors.WrapPathf(err, "metadata", "failed to set metadata") - } - - pairs := []string{} - for k, vs := range md { - vs := vs - for _, v := range vs { - pairs = append(pairs, k, v) - } - } - reqCtx = metadata.AppendToOutgoingContext(reqCtx, pairs...) - } - - var in []reflect.Value - for i := 0; i < method.Type().NumIn(); i++ { - switch i { - case 0: - in = append(in, reflect.ValueOf(reqCtx)) - case 1: - req := reflect.New(method.Type().In(i).Elem()).Interface() - if err := buildRequestMsg(ctx, req, r.Message); err != nil { - return ctx, nil, errors.WrapPathf(err, "message", "failed to build request message") - } - - //nolint:exhaustruct - dumpReq := &Request{ - Method: r.Method, - Message: req, - } - reqMD, _ := metadata.FromOutgoingContext(reqCtx) - if len(reqMD) > 0 { - dumpReq.Metadata = newMDMarshaler(reqMD) - } - ctx = ctx.WithRequest((*RequestExtractor)(dumpReq)) - if b, err := yaml.Marshal(dumpReq); err == nil { - ctx.Reporter().Logf("request:\n%s", r.addIndent(string(b), indentNum)) - } else { - ctx.Reporter().Logf("failed to dump request:\n%s", err) - } - - in = append(in, reflect.ValueOf(req)) - } - } - - var header, trailer metadata.MD - in = append(in, - reflect.ValueOf(grpc.Header(&header)), - reflect.ValueOf(grpc.Trailer(&trailer)), - ) - - rvalues := method.Call(in) - message := rvalues[0].Interface() - var err error - if rvalues[1].IsValid() && rvalues[1].CanInterface() { - e, ok := rvalues[1].Interface().(error) - if ok { - err = e - } - } - resp := response{ - Status: responseStatus{ - Code: codes.OK.String(), - Message: "", - Details: nil, - }, - Message: message, - rvalues: rvalues, - } - if len(header) > 0 { - resp.Header = newMDMarshaler(header) - } - if len(trailer) > 0 { - resp.Trailer = newMDMarshaler(trailer) - } - if err != nil { - if sts, ok := status.FromError(err); ok { - resp.Status.Code = sts.Code().String() - resp.Status.Message = sts.Message() - details := sts.Details() - if l := len(details); l > 0 { - m := make(yaml.MapSlice, l) - for i, d := range details { - item := yaml.MapItem{ - Key: "", - Value: d, - } - if msg, ok := d.(proto.Message); ok { - item.Key = string(proto.MessageName(msg)) - } else { - item.Key = fmt.Sprintf("%T (not proto.Message)", d) - } - m[i] = item - } - resp.Status.Details = m - } - } - } - ctx = ctx.WithResponse((*ResponseExtractor)(&resp)) - if b, err := yaml.Marshal(resp); err == nil { - ctx.Reporter().Logf("response:\n%s", r.addIndent(string(b), indentNum)) - } else { - ctx.Reporter().Logf("failed to dump response:\n%s", err) - } - - return ctx, resp, nil -} - -func buildRequestMsg(ctx *context.Context, req interface{}, src interface{}) error { - x, err := ctx.ExecuteTemplate(src) - if err != nil { - return err - } - if x == nil { - return nil - } - var buf bytes.Buffer - if err := yaml.NewEncoder(&buf, yaml.JSON()).Encode(x); err != nil { - return err - } - message, ok := req.(proto.Message) - if ok { - if err := protojson.Unmarshal(buf.Bytes(), message); err != nil { - return err - } - } - return nil +type serviceClient interface { + invoke(*context.Context, *Request, *RequestOptions) (*context.Context, *response, error) } diff --git a/protocol/grpc/request_custom_client.go b/protocol/grpc/request_custom_client.go new file mode 100644 index 00000000..2fc4ed1d --- /dev/null +++ b/protocol/grpc/request_custom_client.go @@ -0,0 +1,228 @@ +package grpc + +import ( + "bytes" + "fmt" + "reflect" + + "github.com/goccy/go-yaml" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" + + "github.com/zoncoen/scenarigo/context" + "github.com/zoncoen/scenarigo/errors" + "github.com/zoncoen/scenarigo/internal/reflectutil" + "github.com/zoncoen/scenarigo/internal/yamlutil" +) + +type customServiceClient struct { + v reflect.Value +} + +func (client *customServiceClient) invoke(ctx *context.Context, r *Request, opts *RequestOptions) (*context.Context, *response, error) { + var method reflect.Value + for { + if !client.v.IsValid() { + return ctx, nil, errors.ErrorPathf("client", "client %s is invalid", r.Client) + } + method = client.v.MethodByName(r.Method) + if method.IsValid() { + // method found + break + } + switch client.v.Kind() { + case reflect.Interface, reflect.Ptr: + client.v = client.v.Elem() + default: + return ctx, nil, errors.ErrorPathf("method", "method %s.%s not found", r.Client, r.Method) + } + } + + if err := validateMethod(method); err != nil { + return ctx, nil, errors.ErrorPathf("method", `"%s.%s" must be "func(context.Context, proto.Message, ...grpc.CallOption) (proto.Message, error): %s"`, r.Client, r.Method, err) + } + + return invoke(ctx, method, r) +} + +func validateMethod(method reflect.Value) error { + if !method.IsValid() { + return errors.New("invalid") + } + if method.Kind() != reflect.Func { + return errors.New("not function") + } + if method.IsNil() { + return errors.New("method is nil") + } + + mt := method.Type() + if n := mt.NumIn(); n != 3 { + return errors.Errorf("number of arguments must be 3 but got %d", n) + } + if t := mt.In(0); !t.Implements(typeContext) { + return errors.Errorf("first argument must be context.Context but got %s", t.String()) + } + if t := mt.In(1); !t.Implements(typeMessage) { + return errors.Errorf("second argument must be proto.Message but got %s", t.String()) + } + if t := mt.In(2); t != typeCallOpts { + return errors.Errorf("third argument must be []grpc.CallOption but got %s", t.String()) + } + if n := mt.NumOut(); n != 2 { + return errors.Errorf("number of return values must be 2 but got %d", n) + } + if t := mt.Out(0); !t.Implements(typeMessage) { + return errors.Errorf("first return value must be proto.Message but got %s", t.String()) + } + if t := mt.Out(1); !t.Implements(reflectutil.TypeError) { + return errors.Errorf("second return value must be error but got %s", t.String()) + } + + return nil +} + +func invoke(ctx *context.Context, method reflect.Value, r *Request) (*context.Context, *response, error) { + reqCtx := ctx.RequestContext() + if r.Metadata != nil { + x, err := ctx.ExecuteTemplate(r.Metadata) + if err != nil { + return ctx, nil, errors.WrapPathf(err, "metadata", "failed to set metadata") + } + md, err := reflectutil.ConvertStringsMap(reflect.ValueOf(x)) + if err != nil { + return nil, nil, errors.WrapPathf(err, "metadata", "failed to set metadata") + } + + pairs := []string{} + for k, vs := range md { + for _, v := range vs { + pairs = append(pairs, k, v) + } + } + reqCtx = metadata.AppendToOutgoingContext(reqCtx, pairs...) + } + + var in []reflect.Value + for i := 0; i < method.Type().NumIn(); i++ { + switch i { + case 0: + in = append(in, reflect.ValueOf(reqCtx)) + case 1: + req := reflect.New(method.Type().In(i).Elem()).Interface() + if err := buildRequestMsg(ctx, req, r.Message); err != nil { + return ctx, nil, errors.WrapPathf(err, "message", "failed to build request message") + } + + //nolint:exhaustruct + dumpReq := &Request{ + Method: r.Method, + Message: req, + } + reqMD, _ := metadata.FromOutgoingContext(reqCtx) + if len(reqMD) > 0 { + dumpReq.Metadata = yamlutil.NewMDMarshaler(reqMD) + } + ctx = ctx.WithRequest((*RequestExtractor)(dumpReq)) + if b, err := yaml.Marshal(dumpReq); err == nil { + ctx.Reporter().Logf("request:\n%s", r.addIndent(string(b), indentNum)) + } else { + ctx.Reporter().Logf("failed to dump request:\n%s", err) + } + + in = append(in, reflect.ValueOf(req)) + } + } + + var header, trailer metadata.MD + in = append(in, + reflect.ValueOf(grpc.Header(&header)), + reflect.ValueOf(grpc.Trailer(&trailer)), + ) + + rvalues := method.Call(in) + message := rvalues[0].Interface() + var err error + if rvalues[1].IsValid() && rvalues[1].CanInterface() { + e, ok := rvalues[1].Interface().(error) + if ok { + err = e + } + } + resp := &response{ + Status: responseStatus{ + Code: codes.OK.String(), + Message: "", + Details: nil, + }, + Message: message, + rvalues: rvalues, + } + if len(header) > 0 { + resp.Header = yamlutil.NewMDMarshaler(header) + } + if len(trailer) > 0 { + resp.Trailer = yamlutil.NewMDMarshaler(trailer) + } + if err != nil { + if sts, ok := status.FromError(err); ok { + resp.Status.Code = sts.Code().String() + resp.Status.Message = sts.Message() + details := sts.Details() + if l := len(details); l > 0 { + m := make(yaml.MapSlice, l) + for i, d := range details { + item := yaml.MapItem{ + Key: "", + Value: d, + } + if msg, ok := d.(proto.Message); ok { + item.Key = string(proto.MessageName(msg)) + } else { + item.Key = fmt.Sprintf("%T (not proto.Message)", d) + } + m[i] = item + } + resp.Status.Details = m + } + } + } + ctx = ctx.WithResponse((*ResponseExtractor)(resp)) + if b, err := yaml.Marshal(resp); err == nil { + ctx.Reporter().Logf("response:\n%s", r.addIndent(string(b), indentNum)) + } else { + ctx.Reporter().Logf("failed to dump response:\n%s", err) + } + + return ctx, resp, nil +} + +func buildRequestMsg(ctx *context.Context, req interface{}, src interface{}) error { + x, err := ctx.ExecuteTemplate(src) + if err != nil { + return err + } + if x == nil { + return nil + } + msg, ok := req.(proto.Message) + if !ok { + return fmt.Errorf("expect proto.Message but got %T", req) + } + return ConvertToProto(x, msg) +} + +func ConvertToProto(v any, msg proto.Message) error { + var buf bytes.Buffer + if err := yaml.NewEncoder(&buf, yaml.JSON()).Encode(v); err != nil { + return err + } + if err := protojson.Unmarshal(buf.Bytes(), msg); err != nil { + return err + } + return nil +} diff --git a/protocol/grpc/request_test.go b/protocol/grpc/request_test.go index 1ee8ba00..c3ff0808 100644 --- a/protocol/grpc/request_test.go +++ b/protocol/grpc/request_test.go @@ -17,6 +17,7 @@ import ( "github.com/zoncoen/scenarigo/internal/mockutil" "github.com/zoncoen/scenarigo/internal/queryutil" "github.com/zoncoen/scenarigo/internal/testutil" + "github.com/zoncoen/scenarigo/internal/yamlutil" "github.com/zoncoen/scenarigo/reporter" testpb "github.com/zoncoen/scenarigo/testdata/gen/pb/test" "google.golang.org/genproto/googleapis/rpc/errdetails" @@ -97,10 +98,10 @@ func TestResponseExtractor(t *testing.T) { Status: responseStatus{ Code: "OK", }, - Header: &mdMarshaler{ + Header: &yamlutil.MDMarshaler{ "foo": []string{"FOO"}, }, - Trailer: &mdMarshaler{ + Trailer: &yamlutil.MDMarshaler{ "bar": []string{"BAR"}, }, Message: map[string]string{"messageBody": "hey"}, @@ -189,9 +190,9 @@ func TestRequest_Invoke(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %s", err) } - typedResult, ok := result.(response) + typedResult, ok := result.(*response) if !ok { - t.Fatalf("failed to type conversion from %s to response", reflect.TypeOf(result)) + t.Fatalf("failed to type conversion from %s to *response", reflect.TypeOf(result)) } message, serr, err := extract(typedResult) if err != nil { @@ -210,7 +211,7 @@ func TestRequest_Invoke(t *testing.T) { if diff := cmp.Diff((*RequestExtractor)(r), ctx.Request(), protocmp.Transform()); diff != "" { t.Errorf("differs: (-want +got)\n%s", diff) } - if diff := cmp.Diff((*ResponseExtractor)(&typedResult), ctx.Response(), protocmp.Transform(), cmpopts.IgnoreFields(ResponseExtractor{}, "rvalues")); diff != "" { + if diff := cmp.Diff((*ResponseExtractor)(typedResult), ctx.Response(), protocmp.Transform(), cmpopts.IgnoreFields(ResponseExtractor{}, "rvalues")); diff != "" { t.Errorf("differs: (-want +got)\n%s", diff) } }) @@ -237,9 +238,9 @@ func TestRequest_Invoke(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %s", err) } - typedResult, ok := result.(response) + typedResult, ok := result.(*response) if !ok { - t.Fatalf("failed to type conversion from %s to response", reflect.TypeOf(result)) + t.Fatalf("failed to type conversion from %s to *response", reflect.TypeOf(result)) } _, serr, err := extract(typedResult) if err != nil { @@ -625,7 +626,7 @@ metadata: t.Run(name, func(t *testing.T) { b, err := yaml.Marshal(Request{ Method: "Foo", - Metadata: newMDMarshaler(test.md), + Metadata: yamlutil.NewMDMarshaler(test.md), }) if err != nil { t.Fatal(err)