diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..692e351 --- /dev/null +++ b/Makefile @@ -0,0 +1,11 @@ +.PHONY: proto +proto: install + protoc --go-frpc_out=./pkg/generator ./examples/test/test.proto + +.PHONY: install +install: + go install ./protoc-gen-go-frpc + +.PHONY: test +test: proto + go test -v ./... diff --git a/go.mod b/go.mod index 61b2699..e726bca 100644 --- a/go.mod +++ b/go.mod @@ -3,9 +3,9 @@ module github.com/loopholelabs/frpc-go go 1.22 require ( - github.com/loopholelabs/frisbee-go v0.10.0 + github.com/loopholelabs/frisbee-go v0.10.2 github.com/loopholelabs/logging v0.3.1 - github.com/loopholelabs/polyglot/v2 v2.0.2 + github.com/loopholelabs/polyglot/v2 v2.0.3 github.com/loopholelabs/testing v0.2.3 github.com/stretchr/testify v1.9.0 google.golang.org/protobuf v1.35.1 @@ -13,7 +13,7 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect - github.com/loopholelabs/common v0.4.9 // indirect + github.com/loopholelabs/common v0.4.10 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 98b13a4..889e82b 100644 --- a/go.sum +++ b/go.sum @@ -2,18 +2,18 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= -github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/loopholelabs/common v0.4.9 h1:9MPUYlZZ/qx3Kt8LXgXxcSXthrM91od8026c4DlGpAU= -github.com/loopholelabs/common v0.4.9/go.mod h1:Wop5srN1wYT+mdQ9gZ+kn2I9qKAyVd0FB48pThwIa9M= -github.com/loopholelabs/frisbee-go v0.10.0 h1:eqdDqm44V23GMxhjDL9OBz2Fxecsu42M0KHM8kQObBQ= -github.com/loopholelabs/frisbee-go v0.10.0/go.mod h1:RwKglItbNcQq9UW6Vm2aOwOWXR6wTxUNjvuXcCIyJkU= +github.com/loopholelabs/common v0.4.10 h1:BMJSMwH0PiVtdpOlXNPlW827B9WPJ/Gkb/q20NLeOjw= +github.com/loopholelabs/common v0.4.10/go.mod h1:wc17hLpzZaDbndb7Fh3MXQDnhf4Cmf/JKC+LmXaD6II= +github.com/loopholelabs/frisbee-go v0.10.2 h1:TvfONSpoCrrvG7/8QzBbHou2gPqOwN5iFL+4fp3bu4s= +github.com/loopholelabs/frisbee-go v0.10.2/go.mod h1:JJhInJ5zjxpwOLAqEviOjTo1k8E3NNEF/fQTrC4lZ/4= github.com/loopholelabs/logging v0.3.1 h1:VA9DF3WrbmvJC1uQJ/XcWgz8KWXydWwe3BdDiMbN2FY= github.com/loopholelabs/logging v0.3.1/go.mod h1:uRDUydiqPqKbZkb0WoQ3dfyAcJ2iOMhxdEafZssLVv0= -github.com/loopholelabs/polyglot/v2 v2.0.2 h1:v308fg2ZKSvkKDnWgBnDvvmiu4YypCxcDe5Ih5GUVnY= -github.com/loopholelabs/polyglot/v2 v2.0.2/go.mod h1:kFoSKvnKAWmV0ICfbaCHDv/+cz5LSuA+xXG4WtYV/z4= +github.com/loopholelabs/polyglot/v2 v2.0.3 h1:CpH2az5shkOgOBASnzjc1XC5SVzQMbWyHt4R7ds/FFc= +github.com/loopholelabs/polyglot/v2 v2.0.3/go.mod h1:yodgE9ile0RS/npD0WnHfFpMLvL5FlC9n3qZ1tTkB9g= github.com/loopholelabs/testing v0.2.3 h1:4nVuK5ctaE6ua5Z0dYk2l7xTFmcpCYLUeGjRBp8keOA= github.com/loopholelabs/testing v0.2.3/go.mod h1:gqtGY91soYD1fQoKQt/6kP14OYpS7gcbcIgq5mc9m8Q= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= @@ -22,6 +22,8 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= +github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8= github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= @@ -35,7 +37,7 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T google.golang.org/protobuf v1.35.1 h1:m3LfL6/Ca+fqnjnlqQXNpFPABW1UD7mjh8KO2mKFytA= google.golang.org/protobuf v1.35.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/pkg/generator/test/generator_test.go b/pkg/generator/test/generator_test.go index df2770c..eb95c04 100644 --- a/pkg/generator/test/generator_test.go +++ b/pkg/generator/test/generator_test.go @@ -4,11 +4,18 @@ package test import ( "context" + "fmt" "io" + "net/http" + "net/http/httptest" + "os" "testing" + "time" + "github.com/loopholelabs/polyglot/v2" "github.com/loopholelabs/testing/conn/pair" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestRPC(t *testing.T) { @@ -130,3 +137,51 @@ func testClientStreaming(client *Client, t *testing.T) { assert.NoError(t, err) assert.Equal(t, "Hello World", res.Message) } + +func TestRPCInvalidConnection(t *testing.T) { + // Create non-Frisbee server the client can connect to but not exchange + // messages, so the connection will be broken soon after connect. + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "test") + })) + t.Cleanup(ts.Close) + + // Create a client and connect to test server. + client, err := NewClient(nil, nil) + require.NoError(t, err) + + err = client.Connect(ts.Listener.Addr().String()) + require.NoError(t, err) + + // Make RPC request with a 3s timeout. + timeout := 3 * time.Second + if os.Getenv("CI") != "" { + timeout = 10 * time.Second + } + ctx, cancel := context.WithTimeout(context.Background(), timeout) + t.Cleanup(cancel) + + req := &Request{ + Message: "Hello World", + Corpus: RequestUNIVERSAL, + } + response, err := client.EchoService.Echo(ctx, req) + + // Verify request doesn't block forever. + require.NoError(t, ctx.Err()) + + // Verify request fails. + require.Error(t, err) + require.Nil(t, response) +} + +func TestEncodeDecodePreservesNilFields(t *testing.T) { + r := &Response{Message: "test", Test: nil} + b := polyglot.NewBuffer() + r.Encode(b) + + got := &Response{} + err := got.Decode(b.Bytes()) + require.NoError(t, err) + require.Equal(t, r, got) +} diff --git a/pkg/generator/test/test.frpc.go b/pkg/generator/test/test.frpc.go index a924719..65b7420 100644 --- a/pkg/generator/test/test.frpc.go +++ b/pkg/generator/test/test.frpc.go @@ -1,22 +1,22 @@ // Code generated by fRPC Go v0.10.0, DO NOT EDIT. -// source: test.proto +// source: examples/test/test.proto package test import ( - "context" "errors" - "github.com/loopholelabs/polyglot/v2" "net" "sync" + "context" + "github.com/loopholelabs/polyglot/v2" "crypto/tls" "github.com/loopholelabs/frisbee-go" "github.com/loopholelabs/frisbee-go/pkg/packet" "github.com/loopholelabs/logging/types" - "io" "sync/atomic" + "io" ) var ( @@ -114,9 +114,7 @@ type Response struct { } func NewResponse() *Response { - return &Response{ - Test: NewData(), - } + return &Response{} } func (x *Response) Error(b *polyglot.Buffer, err error) { @@ -162,12 +160,12 @@ func (x *Response) decode(d *polyglot.BufferDecoder) error { if err != nil { return err } - if x.Test == nil { + if !d.Nil() { x.Test = NewData() - } - err = x.Test.decode(d) - if err != nil { - return err + err = x.Test.decode(d) + if err != nil { + return err + } } return nil } @@ -646,9 +644,7 @@ type SomeOtherMessage struct { } func NewSomeOtherMessage() *SomeOtherMessage { - return &SomeOtherMessage{ - Result: NewSearchResponseResult(), - } + return &SomeOtherMessage{} } func (x *SomeOtherMessage) Error(b *polyglot.Buffer, err error) { @@ -690,12 +686,12 @@ func (x *SomeOtherMessage) decode(d *polyglot.BufferDecoder) error { if err != nil { return err } - if x.Result == nil { + if !d.Nil() { x.Result = NewSearchResponseResult() - } - err = x.Result.decode(d) - if err != nil { - return err + err = x.Result.decode(d) + if err != nil { + return err + } } return nil } @@ -769,9 +765,7 @@ type OuterMiddleAA struct { } func NewOuterMiddleAA() *OuterMiddleAA { - return &OuterMiddleAA{ - Inner: NewOuterMiddleAAInner(), - } + return &OuterMiddleAA{} } func (x *OuterMiddleAA) Error(b *polyglot.Buffer, err error) { @@ -813,12 +807,12 @@ func (x *OuterMiddleAA) decode(d *polyglot.BufferDecoder) error { if err != nil { return err } - if x.Inner == nil { + if !d.Nil() { x.Inner = NewOuterMiddleAAInner() - } - err = x.Inner.decode(d) - if err != nil { - return err + err = x.Inner.decode(d) + if err != nil { + return err + } } return nil } @@ -892,9 +886,7 @@ type OuterMiddleBB struct { } func NewOuterMiddleBB() *OuterMiddleBB { - return &OuterMiddleBB{ - Inner: NewOuterMiddleBBInner(), - } + return &OuterMiddleBB{} } func (x *OuterMiddleBB) Error(b *polyglot.Buffer, err error) { @@ -936,12 +928,12 @@ func (x *OuterMiddleBB) decode(d *polyglot.BufferDecoder) error { if err != nil { return err } - if x.Inner == nil { + if !d.Nil() { x.Inner = NewOuterMiddleBBInner() - } - err = x.Inner.decode(d) - if err != nil { - return err + err = x.Inner.decode(d) + if err != nil { + return err + } } return nil } @@ -955,10 +947,7 @@ type Outer struct { } func NewOuter() *Outer { - return &Outer{ - A: NewOuterMiddleAA(), - B: NewOuterMiddleBB(), - } + return &Outer{} } func (x *Outer) Error(b *polyglot.Buffer, err error) { @@ -1001,19 +990,19 @@ func (x *Outer) decode(d *polyglot.BufferDecoder) error { if err != nil { return err } - if x.A == nil { + if !d.Nil() { x.A = NewOuterMiddleAA() + err = x.A.decode(d) + if err != nil { + return err + } } - err = x.A.decode(d) - if err != nil { - return err - } - if x.B == nil { + if !d.Nil() { x.B = NewOuterMiddleBB() - } - err = x.B.decode(d) - if err != nil { - return err + err = x.B.decode(d) + if err != nil { + return err + } } return nil } @@ -1475,7 +1464,8 @@ type EchoService interface { Upload(context.Context, *UploadServer) error } -const connectionContextKey int = 1000 +const ConnectionContextKey int = 1000 +const StreamContextKey int = 1001 func SetErrorFlag(flags uint8, error bool) uint8 { return flags | 0x2 @@ -1573,6 +1563,8 @@ func NewServer(echoService EchoService, tlsConfig *tls.Config, logger types.Logg } s.server.SetStreamHandler(func(ctx context.Context, stream *frisbee.Stream) { + streamCtx := context.WithValue(ctx, ConnectionContextKey, stream.Conn()) + streamCtx = context.WithValue(streamCtx, StreamContextKey, stream) p, err := stream.ReadPacket() if err != nil { return @@ -1585,16 +1577,16 @@ func NewServer(echoService EchoService, tlsConfig *tls.Config, logger types.Logg } switch open.operation { case 11: - s.createEchoStreamServer(ctx, echoService, stream) + s.createEchoStreamServer(streamCtx, echoService, stream) case 13: - s.createSearchServer(ctx, echoService, stream) + s.createSearchServer(streamCtx, echoService, stream) case 14: - s.createUploadServer(ctx, echoService, stream) + s.createUploadServer(streamCtx, echoService, stream) } }) s.server.ConnContext = func(ctx context.Context, conn *frisbee.Async) context.Context { - return context.WithValue(ctx, connectionContextKey, conn) + return context.WithValue(ctx, ConnectionContextKey, conn) } return s, nil @@ -1949,6 +1941,8 @@ func (c *subEchoServiceClient) Echo(ctx context.Context, req *Request) (res *Res return } select { + case <-c.client.CloseChannel(): + err = c.client.Error() case res = <-ch: err = res.error case <-ctx.Done(): @@ -2075,6 +2069,8 @@ func (c *subEchoServiceClient) Testy(ctx context.Context, req *SearchResponse) ( return } select { + case <-c.client.CloseChannel(): + err = c.client.Error() case res = <-ch: err = res.error case <-ctx.Done(): diff --git a/templates/client.templ b/templates/client.templ index f32c747..8f67530 100644 --- a/templates/client.templ +++ b/templates/client.templ @@ -237,6 +237,8 @@ func (e CloseError) Error() string { return } select { + case <-c.client.CloseChannel(): + err = c.client.Error() case res = <- ch: err = res.error case <- ctx.Done(): diff --git a/templates/server.templ b/templates/server.templ index 810edc9..baaa9cc 100644 --- a/templates/server.templ +++ b/templates/server.templ @@ -82,6 +82,10 @@ func NewServer({{ GetServerFields .services }}, tlsConfig *tls.Config, logger ty } }) + s.server.StreamContext = func(ctx context.Context, stream *frisbee.Stream) context.Context { + return context.WithValue(ctx, ConnectionContextKey, stream.Conn()) + } + {{ end -}} s.server.ConnContext = func (ctx context.Context, conn *frisbee.Async) context.Context {