Skip to content

Commit

Permalink
feat(mock): implement grpc protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
zoncoen committed Oct 27, 2024
1 parent 37d651a commit 0b86de6
Show file tree
Hide file tree
Showing 23 changed files with 1,380 additions and 293 deletions.
8 changes: 8 additions & 0 deletions assert/nop.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package assert

// Nop returns an assertion that does not assert anything.
func Nop() Assertion {
return AssertionFunc(func(_ interface{}) error {
return nil
})
}
65 changes: 65 additions & 0 deletions internal/queryutil/query.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
package queryutil

import (
"context"
"reflect"
"strings"
"sync"

query "github.com/zoncoen/query-go"
yamlextractor "github.com/zoncoen/query-go/extractor/yaml"
"google.golang.org/protobuf/types/dynamicpb"
)

var (
Expand All @@ -23,11 +27,72 @@ func Options() []query.Option {
[]query.Option{
query.ExtractByStructTag("yaml", "json"),
query.CustomExtractFunc(yamlextractor.MapSliceExtractFunc()),
query.CustomExtractFunc(dynamicpbExtractFunc()),
},
opts...,
)
}

func dynamicpbExtractFunc() func(query.ExtractFunc) query.ExtractFunc {
return func(f query.ExtractFunc) query.ExtractFunc {
return func(in reflect.Value) (reflect.Value, bool) {
v := in
if v.IsValid() && v.CanInterface() {
if msg, ok := v.Interface().(*dynamicpb.Message); ok {
return f(reflect.ValueOf(&keyExtractor{
v: msg,
}))
}
}
return f(in)
}
}
}

type keyExtractor struct {
v *dynamicpb.Message
}

// ExtractByKey implements the query.KeyExtractorContext interface.
func (e *keyExtractor) ExtractByKey(ctx context.Context, key string) (interface{}, bool) {
ci := query.IsCaseInsensitive(ctx)
if ci {
key = strings.ToLower(key)
}
fields := e.v.Descriptor().Fields()
for i := range fields.Len() {
f := fields.Get(i)
{
name := string(f.Name())
if ci {
name = strings.ToLower(name)
}
if name == key {
return e.v.Get(f).Interface(), true
}
}
{
name := f.TextName()
if ci {
name = strings.ToLower(name)
}
if name == key {
return e.v.Get(f).Interface(), true
}
}
if f.HasJSONName() {
name := string(f.JSONName())
if ci {
name = strings.ToLower(name)
}
if name == key {
return e.v.Get(f).Interface(), true
}
}
}
return nil, false
}

func AppendOptions(customOpts ...query.Option) {
m.Lock()
defer m.Unlock()
Expand Down
34 changes: 34 additions & 0 deletions internal/yamlutil/grpc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package yamlutil

import (
"encoding/hex"
"strings"
"unicode/utf8"

"google.golang.org/grpc/metadata"

"github.com/goccy/go-yaml"
)

func NewMDMarshaler(md metadata.MD) *MDMarshaler { return (*MDMarshaler)(&md) }

type MDMarshaler metadata.MD

func (m *MDMarshaler) MarshalYAML() ([]byte, error) {
mp := make(metadata.MD, len(*m))
for k, vs := range *m {

Check failure on line 19 in internal/yamlutil/grpc.go

View workflow job for this annotation

GitHub Actions / golangci

[golangci] internal/yamlutil/grpc.go#L19

cannot range over *m (variable of type MDMarshaler) (typecheck)
Raw output
internal/yamlutil/grpc.go:19:21: cannot range over *m (variable of type MDMarshaler) (typecheck)
	for k, vs := range *m {
	                   ^
if !strings.HasSuffix(k, "-bin") {
mp[k] = vs
continue
}
s := make([]string, len(vs))
for i, v := range vs {
if !utf8.ValidString(v) {
v = hex.EncodeToString([]byte(v))
}
s[i] = v
}
mp[k] = s
}
return yaml.Marshal(mp)
}
205 changes: 205 additions & 0 deletions mock/protocol/grpc/grpc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
package grpc

import (
"context"
"errors"
"fmt"
"net"
"sync"
"time"

"github.com/goccy/go-yaml"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
healthpb "google.golang.org/grpc/health/grpc_health_v1"

"github.com/zoncoen/scenarigo/logger"
"github.com/zoncoen/scenarigo/mock/protocol"
"github.com/zoncoen/scenarigo/protocol/grpc/proto"
)

var waitInterval = 100 * time.Millisecond

// Register registers grpc protocol.
func Register() {
protocol.Register(&GRPC{})
}

// GRPC is a protocol type for the mock.
type GRPC struct{}

// Name implements protocol.Protocol interface.
func (_ GRPC) Name() string { return "grpc" } //nolint:revive

// UnmarshalConfig implements protocol.Protocol interface.
func (_ GRPC) UnmarshalConfig(b []byte) (interface{}, error) { //nolint:revive
var config ServerConfig
if err := yaml.Unmarshal(b, &config); err != nil {
return nil, err
}
return &config, nil
}

// NewServer implements protocol.Protocol interface.
func (_ *GRPC) NewServer(iter *protocol.MockIterator, l logger.Logger, config interface{}) (protocol.Server, error) { //nolint:revive
if iter == nil {
return nil, errors.New("mock iterator is nil")
}
cfg, ok := config.(*ServerConfig)
if !ok {
return nil, fmt.Errorf("invalid config %T", config)
}
//nolint:exhaustruct
srv := &server{
iter: iter,
}
if cfg != nil {
srv.config = *cfg
comp := proto.NewCompiler(cfg.Proto.Imports)
fds, err := comp.Compile(context.Background(), cfg.Proto.Files)
if err != nil {
return nil, fmt.Errorf("failed to compile proto: %w", err)
}
srv.resolver = fds
}
return srv, nil
}

// ServerConfig represents a server configuration.
type ServerConfig struct {
Port int `yaml:"port,omitempty"`
Proto ProtoConfig `yaml:"proto,omitempty"`
}

// ProtoConfig represents a proto configuration.
type ProtoConfig struct {
Imports []string `yaml:"imports,omitempty"`
Files []string `yaml:"files,omitempty"`
}

type server struct {
m sync.Mutex
config ServerConfig
iter *protocol.MockIterator
resolver proto.ServiceDescriptorResolver
addr string
srv *grpc.Server
}

// Start implements protocol.Server interface.
func (s *server) Start(ctx context.Context) error {
s.m.Lock()
serve, err := s.setup()
if err != nil {
s.m.Unlock()
return err
}
s.m.Unlock()
return serve()
}

func (s *server) setup() (func() error, error) {
if s.srv != nil {
return nil, errors.New("server already started")
}
ln, err := net.Listen("tcp", fmt.Sprintf(":%d", s.config.Port))
if err != nil {
return nil, fmt.Errorf("failed to listen: %w", err)
}
s.addr = ln.Addr().String()
s.srv = grpc.NewServer()
healthpb.RegisterHealthServer(s.srv, &healthServer{})
names, err := s.resolver.ListServices()
if err != nil {
return nil, fmt.Errorf("failed to get service descriptor: %w", err)
}
for _, name := range names {
sd, err := s.resolver.ResolveService(name)
if err != nil {
return nil, fmt.Errorf("failed to get service descriptor: %w", err)
}
s.srv.RegisterService(s.convertToServicDesc(sd), nil)
}
return func() error {
if err := s.srv.Serve(ln); err != nil {
if !errors.Is(err, grpc.ErrServerStopped) {
return err
}
}
return nil
}, nil
}

// Wait implements protocol.Server interface.
func (s *server) Wait(ctx context.Context) error {
ch := make(chan error)
go func() {
ch <- s.wait(ctx)
}()
select {
case <-ctx.Done():
return context.Canceled
case err := <-ch:
return err
}
}

func (s *server) wait(ctx context.Context) error {
var once sync.Once
var client healthpb.HealthClient
for {
if err := ctx.Err(); err != nil {
return err
}
s.m.Lock()
srv := s.srv
s.m.Unlock()
if srv != nil {
var err error
once.Do(func() {
c, cErr := grpc.NewClient(s.addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
if cErr != nil {
err = fmt.Errorf("failed to connect server: %w", err)
return
}
client = healthpb.NewHealthClient(c)
})
if err != nil {
return err
}
resp, err := client.Check(ctx, &healthpb.HealthCheckRequest{
Service: "grpc.health.v1",
})
if err == nil {
if resp.GetStatus() == healthpb.HealthCheckResponse_SERVING {
return nil
}
}
}
time.Sleep(waitInterval)
}
}

// Stop implements protocol.Server interface.
func (s *server) Stop(ctx context.Context) error {
s.m.Lock()
defer s.m.Unlock()
if s.srv == nil {
return protocol.ErrServerClosed
}
s.addr = ""
srv := s.srv
s.srv = nil
srv.GracefulStop() // GracefulStop() calls s.ln.Close()
return nil
}

// Addr implements protocol.Server interface.
func (s *server) Addr() (string, error) {
s.m.Lock()
defer s.m.Unlock()
if s.srv == nil {
return "", protocol.ErrServerClosed
}
return s.addr, nil
}
Loading

0 comments on commit 0b86de6

Please sign in to comment.