Skip to content

Commit

Permalink
fix(agent): allow context propagation in grpc messages (#3823)
Browse files Browse the repository at this point in the history
  • Loading branch information
schoren authored Apr 22, 2024
1 parent 2d9591f commit c39c09b
Show file tree
Hide file tree
Showing 21 changed files with 990 additions and 801 deletions.
4 changes: 4 additions & 0 deletions .editorconfig
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,7 @@ trim_trailing_whitespace = false

[COMMIT_EDITMSG]
max_line_length = 0

[*.proto]
indent_style = space
indent_size = 4
42 changes: 6 additions & 36 deletions agent/client/mocks/grpc_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,6 @@ import (
"github.com/avast/retry-go"
"github.com/kubeshop/tracetest/agent/client"
"github.com/kubeshop/tracetest/agent/proto"
"github.com/kubeshop/tracetest/server/telemetry"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
"go.opentelemetry.io/otel/propagation"
"google.golang.org/grpc"
)

Expand Down Expand Up @@ -73,13 +70,7 @@ func (s *GrpcServerMock) start(wg *sync.WaitGroup, port int) error {

s.port = lis.Addr().(*net.TCPAddr).Port

server := grpc.NewServer(
grpc.UnaryInterceptor(otelgrpc.UnaryServerInterceptor(
otelgrpc.WithPropagators(
propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{}),
),
)),
)
server := grpc.NewServer()
proto.RegisterOrchestratorServer(server, s)

s.server = server
Expand Down Expand Up @@ -121,12 +112,7 @@ func (s *GrpcServerMock) RegisterTriggerAgent(id *proto.AgentIdentification, str

for {
triggerRequest := <-s.triggerChannel
err := telemetry.InjectContextIntoStream(triggerRequest.Context, stream)
if err != nil {
log.Println(err.Error())
}

err = stream.Send(triggerRequest.Data)
err := stream.Send(triggerRequest.Data)
if err != nil {
log.Println("could not send trigger request to agent: %w", err)
}
Expand All @@ -150,12 +136,8 @@ func (s *GrpcServerMock) RegisterPollerAgent(id *proto.AgentIdentification, stre

for {
pollerRequest := <-s.pollingChannel
err := telemetry.InjectContextIntoStream(pollerRequest.Context, stream)
if err != nil {
log.Println(err.Error())
}

err = stream.Send(pollerRequest.Data)
err := stream.Send(pollerRequest.Data)
if err != nil {
log.Println("could not send polling request to agent: %w", err)
}
Expand All @@ -169,12 +151,8 @@ func (s *GrpcServerMock) RegisterDataStoreConnectionTestAgent(id *proto.AgentIde

for {
dsTestRequest := <-s.dataStoreTestChannel
err := telemetry.InjectContextIntoStream(dsTestRequest.Context, stream)
if err != nil {
log.Println(err.Error())
}

err = stream.Send(dsTestRequest.Data)
err := stream.Send(dsTestRequest.Data)
if err != nil {
log.Println("could not send polling request to agent: %w", err)
}
Expand All @@ -188,12 +166,8 @@ func (s *GrpcServerMock) RegisterOTLPConnectionTestListener(id *proto.AgentIdent

for {
testRequest := <-s.otlpConnectionTestChannel
err := telemetry.InjectContextIntoStream(testRequest.Context, stream)
if err != nil {
log.Println(err.Error())
}

err = stream.Send(testRequest.Data)
err := stream.Send(testRequest.Data)
if err != nil {
log.Println("could not send polling request to agent: %w", err)
}
Expand Down Expand Up @@ -230,12 +204,8 @@ func (s *GrpcServerMock) SendPolledSpans(ctx context.Context, result *proto.Poll
func (s *GrpcServerMock) RegisterShutdownListener(_ *proto.AgentIdentification, stream proto.Orchestrator_RegisterShutdownListenerServer) error {
for {
shutdownRequest := <-s.terminationChannel
err := telemetry.InjectContextIntoStream(shutdownRequest.Context, stream)
if err != nil {
log.Println(err.Error())
}

err = stream.Send(shutdownRequest.Data)
err := stream.Send(shutdownRequest.Data)
if err != nil {
log.Println("could not send polling request to agent: %w", err)
}
Expand Down
10 changes: 3 additions & 7 deletions agent/client/workflow_listen_for_ds_connection_tests.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"time"

"github.com/kubeshop/tracetest/agent/proto"
"github.com/kubeshop/tracetest/server/telemetry"
"github.com/kubeshop/tracetest/agent/telemetry"
"go.uber.org/zap"
)

Expand Down Expand Up @@ -48,12 +48,8 @@ func (c *Client) startDataStoreConnectionTestListener(ctx context.Context) error
continue
}

ctx, err := telemetry.ExtractContextFromStream(stream)
if err != nil {
logger.Error("could not extract context from stream", zap.Error(err))
log.Println("could not extract context from stream %w", err)
}

// we want a new context per request, not to reuse the one from the stream
ctx := telemetry.InjectMetadataIntoContext(context.Background(), req.Metadata)
err = c.dataStoreConnectionListener(ctx, &req)
if err != nil {
logger.Error("could not handle data store connection test request", zap.Error(err))
Expand Down
10 changes: 3 additions & 7 deletions agent/client/workflow_listen_for_otlp_connection_tests.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"time"

"github.com/kubeshop/tracetest/agent/proto"
"github.com/kubeshop/tracetest/server/telemetry"
"github.com/kubeshop/tracetest/agent/telemetry"
"go.uber.org/zap"
)

Expand Down Expand Up @@ -48,12 +48,8 @@ func (c *Client) startOTLPConnectionTestListener(ctx context.Context) error {
continue
}

ctx, err := telemetry.ExtractContextFromStream(stream)
if err != nil {
logger.Error("could not extract context from stream", zap.Error(err))
log.Println("could not extract context from stream %w", err)
}

// we want a new context per request, not to reuse the one from the stream
ctx := telemetry.InjectMetadataIntoContext(context.Background(), req.Metadata)
err = c.otlpConnectionTestListener(ctx, &req)
if err != nil {
logger.Error("could not handle otlp connection test request", zap.Error(err))
Expand Down
16 changes: 6 additions & 10 deletions agent/client/workflow_listen_for_poll_requests.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"time"

"github.com/kubeshop/tracetest/agent/proto"
"github.com/kubeshop/tracetest/server/telemetry"
"github.com/kubeshop/tracetest/agent/telemetry"
"go.uber.org/zap"
)

Expand All @@ -24,8 +24,8 @@ func (c *Client) startPollerListener(ctx context.Context) error {

go func() {
for {
resp := proto.PollingRequest{}
err := stream.RecvMsg(&resp)
req := proto.PollingRequest{}
err := stream.RecvMsg(&req)
if err != nil {
logger.Error("could not get message from poller stream", zap.Error(err))
}
Expand All @@ -47,13 +47,9 @@ func (c *Client) startPollerListener(ctx context.Context) error {
continue
}

ctx, err := telemetry.ExtractContextFromStream(stream)
if err != nil {
logger.Error("could not extract context from stream", zap.Error(err))
log.Println("could not extract context from stream %w", err)
}

err = c.pollListener(ctx, &resp)
// we want a new context per request, not to reuse the one from the stream
ctx := telemetry.InjectMetadataIntoContext(context.Background(), req.Metadata)
err = c.pollListener(ctx, &req)
if err != nil {
logger.Error("could not handle poll request", zap.Error(err))
fmt.Println(err.Error())
Expand Down
10 changes: 6 additions & 4 deletions agent/client/workflow_listen_for_stop_requests.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"time"

"github.com/kubeshop/tracetest/agent/proto"
"github.com/kubeshop/tracetest/agent/telemetry"
"go.uber.org/zap"
)

Expand All @@ -25,8 +26,8 @@ func (c *Client) startStopListener(ctx context.Context) error {

go func() {
for {
resp := proto.StopRequest{}
err := stream.RecvMsg(&resp)
req := proto.StopRequest{}
err := stream.RecvMsg(&req)
if err != nil {
logger.Error("could not get message from stop stream", zap.Error(err))
}
Expand All @@ -48,8 +49,9 @@ func (c *Client) startStopListener(ctx context.Context) error {
continue
}

// TODO: get context from request
err = c.stopListener(context.Background(), &resp)
// we want a new context per request, not to reuse the one from the stream
ctx := telemetry.InjectMetadataIntoContext(context.Background(), req.Metadata)
err = c.stopListener(ctx, &req)
if err != nil {
logger.Error("could not handle stop request", zap.Error(err))
fmt.Println(err.Error())
Expand Down
16 changes: 6 additions & 10 deletions agent/client/workflow_listen_for_trigger_requests.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"time"

"github.com/kubeshop/tracetest/agent/proto"
"github.com/kubeshop/tracetest/server/telemetry"
"github.com/kubeshop/tracetest/agent/telemetry"
"go.uber.org/zap"
)

Expand All @@ -25,8 +25,8 @@ func (c *Client) startTriggerListener(ctx context.Context) error {

go func() {
for {
resp := proto.TriggerRequest{}
err := stream.RecvMsg(&resp)
req := proto.TriggerRequest{}
err := stream.RecvMsg(&req)
if err != nil {
logger.Error("could not get message from stop stream", zap.Error(err))
}
Expand All @@ -48,13 +48,9 @@ func (c *Client) startTriggerListener(ctx context.Context) error {
continue
}

ctx, err := telemetry.ExtractContextFromStream(stream)
if err != nil {
logger.Error("could not extract context from stream", zap.Error(err))
log.Println("could not extract context from stream %w", err)
}

err = c.triggerListener(ctx, &resp)
// we want a new context per request, not to reuse the one from the stream
ctx := telemetry.InjectMetadataIntoContext(context.Background(), req.Metadata)
err = c.triggerListener(ctx, &req)
if err != nil {
logger.Error("could not handle trigger request", zap.Error(err))
fmt.Println(err.Error())
Expand Down
2 changes: 2 additions & 0 deletions agent/client/workflow_send_ds_connection_result.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@ import (
"fmt"

"github.com/kubeshop/tracetest/agent/proto"
"github.com/kubeshop/tracetest/agent/telemetry"
)

func (c *Client) SendDataStoreConnectionResult(ctx context.Context, response *proto.DataStoreConnectionTestResponse) error {
client := proto.NewOrchestratorClient(c.conn)

response.AgentIdentification = c.sessionConfig.AgentIdentification
response.Metadata = telemetry.ExtractMetadataFromContext(ctx)

_, err := client.SendDataStoreConnectionTestResult(ctx, response)
if err != nil {
Expand Down
2 changes: 2 additions & 0 deletions agent/client/workflow_send_otlp_connection_result.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@ import (
"fmt"

"github.com/kubeshop/tracetest/agent/proto"
"github.com/kubeshop/tracetest/agent/telemetry"
)

func (c *Client) SendOTLPConnectionResult(ctx context.Context, response *proto.OTLPConnectionTestResponse) error {
client := proto.NewOrchestratorClient(c.conn)

response.AgentIdentification = c.sessionConfig.AgentIdentification
response.Metadata = telemetry.ExtractMetadataFromContext(ctx)

_, err := client.SendOTLPConnectionTestResult(ctx, response)
if err != nil {
Expand Down
8 changes: 5 additions & 3 deletions agent/client/workflow_send_trace.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,16 @@ import (
"fmt"

"github.com/kubeshop/tracetest/agent/proto"
"github.com/kubeshop/tracetest/agent/telemetry"
)

func (c *Client) SendTrace(ctx context.Context, pollingResponse *proto.PollingResponse) error {
func (c *Client) SendTrace(ctx context.Context, response *proto.PollingResponse) error {
client := proto.NewOrchestratorClient(c.conn)

pollingResponse.AgentIdentification = c.sessionConfig.AgentIdentification
response.AgentIdentification = c.sessionConfig.AgentIdentification
response.Metadata = telemetry.ExtractMetadataFromContext(ctx)

_, err := client.SendPolledSpans(ctx, pollingResponse)
_, err := client.SendPolledSpans(ctx, response)
if err != nil {
return fmt.Errorf("could not send polled spans result request: %w", err)
}
Expand Down
3 changes: 3 additions & 0 deletions agent/client/workflow_send_trigger_response.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"

"github.com/kubeshop/tracetest/agent/proto"
"github.com/kubeshop/tracetest/agent/telemetry"
)

func (c *Client) SendTriggerResponse(ctx context.Context, response *proto.TriggerResponse) error {
Expand All @@ -14,6 +15,8 @@ func (c *Client) SendTriggerResponse(ctx context.Context, response *proto.Trigge
response.AgentIdentification = c.sessionConfig.AgentIdentification
}

response.Metadata = telemetry.ExtractMetadataFromContext(ctx)

_, err := client.SendTriggerResult(ctx, response)
if err != nil {
return fmt.Errorf("could not send trigger result request: %w", err)
Expand Down
Loading

0 comments on commit c39c09b

Please sign in to comment.