Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[signal] Fix context propagation in signal server #3251

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 53 additions & 48 deletions signal/server/signal.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,13 @@ func NewServer(ctx context.Context, meter metric.Meter) (*Server, error) {
return nil, fmt.Errorf("creating app metrics: %v", err)
}

dispatcher, err := dispatcher.NewDispatcher(ctx, meter)
d, err := dispatcher.NewDispatcher(ctx, meter)
if err != nil {
return nil, fmt.Errorf("creating dispatcher: %v", err)
}

s := &Server{
dispatcher: dispatcher,
dispatcher: d,
registry: peer.NewRegistry(appMetrics),
metrics: appMetrics,
}
Expand All @@ -75,7 +75,7 @@ func (s *Server) Send(ctx context.Context, msg *proto.EncryptedMessage) (*proto.
return &proto.EncryptedMessage{}, nil
}

return s.dispatcher.SendMessage(context.Background(), msg)
return s.dispatcher.SendMessage(ctx, msg)
}

// ConnectStream connects to the exchange stream
Expand All @@ -98,76 +98,81 @@ func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer)
log.Debugf("peer connected [%s] [streamID %d] ", p.Id, p.StreamID)

for {
// read incoming messages
msg, err := stream.Recv()
if err == io.EOF {
break
} else if err != nil {
return err
}

log.Debugf("Received a response from peer [%s] to peer [%s]", msg.Key, msg.RemoteKey)

_, err = s.dispatcher.SendMessage(stream.Context(), msg)
if err != nil {
log.Debugf("error while sending message from peer [%s] to peer [%s] %v", msg.Key, msg.RemoteKey, err)
select {
case <-stream.Context().Done():
log.Debugf("stream closed for peer [%s] [streamID %d] due to context cancellation", p.Id, p.StreamID)
return stream.Context().Err()
default:
// read incoming messages
msg, err := stream.Recv()
if err == io.EOF {
break
} else if err != nil {
return err
}

log.Debugf("Received a response from peer [%s] to peer [%s]", msg.Key, msg.RemoteKey)

_, err = s.dispatcher.SendMessage(stream.Context(), msg)
if err != nil {
log.Debugf("error while sending message from peer [%s] to peer [%s] %v", msg.Key, msg.RemoteKey, err)
}
}
}

<-stream.Context().Done()
return stream.Context().Err()
}

func (s *Server) RegisterPeer(stream proto.SignalExchange_ConnectStreamServer) (*peer.Peer, error) {
log.Debugf("registering new peer")
if meta, hasMeta := metadata.FromIncomingContext(stream.Context()); hasMeta {
if id, found := meta[proto.HeaderId]; found {
p := peer.NewPeer(id[0], stream)

s.registry.Register(p)
s.dispatcher.ListenForMessages(stream.Context(), p.Id, s.forwardMessageToPeer)

return p, nil
} else {
s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorMissingId)))
return nil, status.Errorf(codes.FailedPrecondition, "missing connection header: "+proto.HeaderId)
}
} else {
meta, hasMeta := metadata.FromIncomingContext(stream.Context())
if !hasMeta {
s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorMissingMeta)))
return nil, status.Errorf(codes.FailedPrecondition, "missing connection stream meta")
}

id, found := meta[proto.HeaderId]
if !found {
s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorMissingId)))
return nil, status.Errorf(codes.FailedPrecondition, "missing connection header: %s", proto.HeaderId)
}

p := peer.NewPeer(id[0], stream)
s.registry.Register(p)
s.dispatcher.ListenForMessages(stream.Context(), p.Id, s.forwardMessageToPeer)
return p, nil
}

func (s *Server) DeregisterPeer(p *peer.Peer) {
log.Debugf("peer disconnected [%s] [streamID %d] ", p.Id, p.StreamID)
s.registry.Deregister(p)

s.metrics.PeerConnectionDuration.Record(p.Stream.Context(), int64(time.Since(p.RegisteredAt).Seconds()))
}

func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedMessage) {
log.Debugf("forwarding a new message from peer [%s] to peer [%s]", msg.Key, msg.RemoteKey)

getRegistrationStart := time.Now()

// lookup the target peer where the message is going to
if dstPeer, found := s.registry.Get(msg.RemoteKey); found {
s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationFound)))
start := time.Now()
// forward the message to the target peer
if err := dstPeer.Stream.Send(msg); err != nil {
log.Warnf("error while forwarding message from peer [%s] to peer [%s] %v", msg.Key, msg.RemoteKey, err)
// todo respond to the sender?
s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeError)))
} else {
// in milliseconds
s.metrics.MessageForwardLatency.Record(ctx, float64(time.Since(start).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream)))
s.metrics.MessagesForwarded.Add(ctx, 1)
}
} else {
dstPeer, found := s.registry.Get(msg.RemoteKey)

if !found {
s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationNotFound)))
s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeNotConnected)))
log.Debugf("message from peer [%s] can't be forwarded to peer [%s] because destination peer is not connected", msg.Key, msg.RemoteKey)
// todo respond to the sender?
}

s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationFound)))
start := time.Now()

// forward the message to the target peer
if err := dstPeer.Stream.Send(msg); err != nil {
log.Warnf("error while forwarding message from peer [%s] to peer [%s] %v", msg.Key, msg.RemoteKey, err)
// todo respond to the sender?
s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeError)))
return
}

// in milliseconds
s.metrics.MessageForwardLatency.Record(ctx, float64(time.Since(start).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream)))
s.metrics.MessagesForwarded.Add(ctx, 1)
}