diff --git a/signal/server/signal.go b/signal/server/signal.go index 305fd052b2e..abc1c367bc5 100644 --- a/signal/server/signal.go +++ b/signal/server/signal.go @@ -52,13 +52,13 @@ func NewServer(ctx context.Context, meter metric.Meter) (*Server, error) { return nil, fmt.Errorf("creating app metrics: %v", err) } - dispatcher, err := dispatcher.NewDispatcher(ctx, meter) + d, err := dispatcher.NewDispatcher(ctx, meter) if err != nil { return nil, fmt.Errorf("creating dispatcher: %v", err) } s := &Server{ - dispatcher: dispatcher, + dispatcher: d, registry: peer.NewRegistry(appMetrics), metrics: appMetrics, } @@ -75,7 +75,7 @@ func (s *Server) Send(ctx context.Context, msg *proto.EncryptedMessage) (*proto. return &proto.EncryptedMessage{}, nil } - return s.dispatcher.SendMessage(context.Background(), msg) + return s.dispatcher.SendMessage(ctx, msg) } // ConnectStream connects to the exchange stream @@ -98,76 +98,81 @@ func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer) log.Debugf("peer connected [%s] [streamID %d] ", p.Id, p.StreamID) for { - // read incoming messages - msg, err := stream.Recv() - if err == io.EOF { - break - } else if err != nil { - return err - } - - log.Debugf("Received a response from peer [%s] to peer [%s]", msg.Key, msg.RemoteKey) - - _, err = s.dispatcher.SendMessage(stream.Context(), msg) - if err != nil { - log.Debugf("error while sending message from peer [%s] to peer [%s] %v", msg.Key, msg.RemoteKey, err) + select { + case <-stream.Context().Done(): + log.Debugf("stream closed for peer [%s] [streamID %d] due to context cancellation", p.Id, p.StreamID) + return stream.Context().Err() + default: + // read incoming messages + msg, err := stream.Recv() + if err == io.EOF { + break + } else if err != nil { + return err + } + + log.Debugf("Received a response from peer [%s] to peer [%s]", msg.Key, msg.RemoteKey) + + _, err = s.dispatcher.SendMessage(stream.Context(), msg) + if err != nil { + log.Debugf("error while sending message from peer [%s] to peer [%s] %v", msg.Key, msg.RemoteKey, err) + } } } - - <-stream.Context().Done() - return stream.Context().Err() } func (s *Server) RegisterPeer(stream proto.SignalExchange_ConnectStreamServer) (*peer.Peer, error) { log.Debugf("registering new peer") - if meta, hasMeta := metadata.FromIncomingContext(stream.Context()); hasMeta { - if id, found := meta[proto.HeaderId]; found { - p := peer.NewPeer(id[0], stream) - - s.registry.Register(p) - s.dispatcher.ListenForMessages(stream.Context(), p.Id, s.forwardMessageToPeer) - - return p, nil - } else { - s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorMissingId))) - return nil, status.Errorf(codes.FailedPrecondition, "missing connection header: "+proto.HeaderId) - } - } else { + meta, hasMeta := metadata.FromIncomingContext(stream.Context()) + if !hasMeta { s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorMissingMeta))) return nil, status.Errorf(codes.FailedPrecondition, "missing connection stream meta") } + + id, found := meta[proto.HeaderId] + if !found { + s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorMissingId))) + return nil, status.Errorf(codes.FailedPrecondition, "missing connection header: %s", proto.HeaderId) + } + + p := peer.NewPeer(id[0], stream) + s.registry.Register(p) + s.dispatcher.ListenForMessages(stream.Context(), p.Id, s.forwardMessageToPeer) + return p, nil } func (s *Server) DeregisterPeer(p *peer.Peer) { log.Debugf("peer disconnected [%s] [streamID %d] ", p.Id, p.StreamID) s.registry.Deregister(p) - s.metrics.PeerConnectionDuration.Record(p.Stream.Context(), int64(time.Since(p.RegisteredAt).Seconds())) } func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedMessage) { log.Debugf("forwarding a new message from peer [%s] to peer [%s]", msg.Key, msg.RemoteKey) - getRegistrationStart := time.Now() // lookup the target peer where the message is going to - if dstPeer, found := s.registry.Get(msg.RemoteKey); found { - s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationFound))) - start := time.Now() - // forward the message to the target peer - if err := dstPeer.Stream.Send(msg); err != nil { - log.Warnf("error while forwarding message from peer [%s] to peer [%s] %v", msg.Key, msg.RemoteKey, err) - // todo respond to the sender? - s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeError))) - } else { - // in milliseconds - s.metrics.MessageForwardLatency.Record(ctx, float64(time.Since(start).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream))) - s.metrics.MessagesForwarded.Add(ctx, 1) - } - } else { + dstPeer, found := s.registry.Get(msg.RemoteKey) + + if !found { s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationNotFound))) s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeNotConnected))) log.Debugf("message from peer [%s] can't be forwarded to peer [%s] because destination peer is not connected", msg.Key, msg.RemoteKey) // todo respond to the sender? } + + s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationFound))) + start := time.Now() + + // forward the message to the target peer + if err := dstPeer.Stream.Send(msg); err != nil { + log.Warnf("error while forwarding message from peer [%s] to peer [%s] %v", msg.Key, msg.RemoteKey, err) + // todo respond to the sender? + s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeError))) + return + } + + // in milliseconds + s.metrics.MessageForwardLatency.Record(ctx, float64(time.Since(start).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream))) + s.metrics.MessagesForwarded.Add(ctx, 1) }