diff --git a/pkg/conference/data_channel_message_processor.go b/pkg/conference/data_channel_message_processor.go
new file mode 100644
index 0000000..4edf49e
--- /dev/null
+++ b/pkg/conference/data_channel_message_processor.go
@@ -0,0 +1,78 @@
+package conference
+
+import (
+    "github.com/pion/webrtc/v3"
+    "golang.org/x/exp/slices"
+    "maunium.net/go/mautrix/event"
+)
+
+// Handle the `SFUMessage` event from the DataChannel message.
+func (c *Conference) processSelectDCMessage(participant *Participant, msg event.SFUMessage) {
+    participant.logger.Info("Received select request over DC")
+
+    // Find tracks based on what we were asked for.
+    tracks := c.getTracks(msg.Start)
+
+    // Let's check if all the tracks that we were asked for are there.
+    // If not, we will list which ones are not available (later on we must inform the participant
+    // about it, unless the participant retries it).
+    if len(tracks) != len(msg.Start) {
+        for _, expected := range msg.Start {
+            found := slices.IndexFunc(tracks, func(track *webrtc.TrackLocalStaticRTP) bool {
+                return track.StreamID() == expected.StreamID && track.ID() == expected.TrackID
+            })
+
+            if found == -1 {
+                c.logger.Warnf("Track not found: %s", expected.TrackID)
+            }
+        }
+    }
+
+    // Subscribe to the found tracks.
+    for _, track := range tracks {
+        if err := participant.peer.SubscribeTo(track); err != nil {
+            participant.logger.Errorf("Failed to subscribe to track: %v", err)
+            return
+        }
+    }
+}
+
+func (c *Conference) processAnswerDCMessage(participant *Participant, msg event.SFUMessage) {
+    participant.logger.Info("Received SDP answer over DC")
+
+    if err := participant.peer.ProcessSDPAnswer(msg.SDP); err != nil {
+        participant.logger.Errorf("Failed to set SDP answer: %v", err)
+        return
+    }
+}
+
+func (c *Conference) processPublishDCMessage(participant *Participant, msg event.SFUMessage) {
+    participant.logger.Info("Received SDP offer over DC")
+
+    answer, err := participant.peer.ProcessSDPOffer(msg.SDP)
+    if err != nil {
+        participant.logger.Errorf("Failed to set SDP offer: %v", err)
+        return
+    }
+
+    participant.streamMetadata = msg.Metadata
+
+    participant.sendDataChannelMessage(event.SFUMessage{
+        Op:       event.SFUOperationAnswer,
+        SDP:      answer.SDP,
+        Metadata: c.getAvailableStreamsFor(participant.id),
+    })
+}
+
+func (c *Conference) processUnpublishDCMessage(participant *Participant) {
+    participant.logger.Info("Received unpublish over DC")
+}
+
+func (c *Conference) processAliveDCMessage(participant *Participant) {
+    participant.peer.ProcessHeartbeat()
+}
+
+func (c *Conference) processMetadataDCMessage(participant *Participant, msg event.SFUMessage) {
+    participant.streamMetadata = msg.Metadata
+    c.resendMetadataToAllExcept(participant.id)
+}
diff --git a/pkg/conference/matrix.go b/pkg/conference/matrix_message_processor.go
similarity index 98%
rename from pkg/conference/matrix.go
rename to pkg/conference/matrix_message_processor.go
index 61bd40c..691a1f5 100644
--- a/pkg/conference/matrix.go
+++ b/pkg/conference/matrix_message_processor.go
@@ -65,7 +65,7 @@ func (c *Conference) onNewParticipant(participantID ParticipantID, inviteEvent *
         logger:          logger,
         remoteSessionID: inviteEvent.SenderSessionID,
         streamMetadata:  inviteEvent.SDPStreamMetadata,
-        publishedTracks: make(map[event.SFUTrackDescription]*webrtc.TrackLocalStaticRTP),
+        publishedTracks: make(map[event.SFUTrackDescription]PublishedTrack),
     }
 
     c.participants[participantID] = participant
diff --git a/pkg/conference/messsage_processor.go b/pkg/conference/messsage_processor.go
new file mode 100644
index 0000000..3865aac
--- /dev/null
+++ b/pkg/conference/messsage_processor.go
@@ -0,0 +1,85 @@
+package conference
+
+import (
+    "errors"
+
+    "github.com/matrix-org/waterfall/pkg/common"
+    "github.com/matrix-org/waterfall/pkg/peer"
+    "maunium.net/go/mautrix/event"
+)
+
+// Listen on messages from incoming channels and process them.
+// This is essentially the main loop of the conference.
+// If this function returns, the conference is over.
+func (c *Conference) processMessages() {
+    for {
+        select {
+        case msg := <-c.peerMessages:
+            c.processPeerMessage(msg)
+        case msg := <-c.matrixMessages.Channel:
+            c.processMatrixMessage(msg)
+        }
+
+        // If there are no more participants, stop the conference.
+        if len(c.participants) == 0 {
+            c.logger.Info("No more participants, stopping the conference")
+            // Close the channel so that the sender can't push any messages.
+            unreadMessages := c.matrixMessages.Close()
+
+            // Send the information that we ended to the owner and pass the messages
+            // that we did not process (so that we don't drop them silently).
+            c.endNotifier.Notify(unreadMessages)
+            return
+        }
+    }
+}
+
+// Process a message from a local peer.
+func (c *Conference) processPeerMessage(message common.Message[ParticipantID, peer.MessageContent]) {
+    participant := c.getParticipant(message.Sender, errors.New("received a message from a deleted participant"))
+    if participant == nil {
+        return
+    }
+
+    // Since Go does not support ADTs, we have to use a switch statement to
+    // determine the actual type of the message.
+    switch msg := message.Content.(type) {
+    case peer.JoinedTheCall:
+        c.processJoinedTheCallMessage(participant, msg)
+    case peer.LeftTheCall:
+        c.processLeftTheCallMessage(participant, msg)
+    case peer.NewTrackPublished:
+        c.processNewTrackPublishedMessage(participant, msg)
+    case peer.PublishedTrackFailed:
+        c.processPublishedTrackFailedMessage(participant, msg)
+    case peer.NewICECandidate:
+        c.processNewICECandidateMessage(participant, msg)
+    case peer.ICEGatheringComplete:
+        c.processICEGatheringCompleteMessage(participant, msg)
+    case peer.RenegotiationRequired:
+        c.processRenegotiationRequiredMessage(participant, msg)
+    case peer.DataChannelMessage:
+        c.processDataChannelMessage(participant, msg)
+    case peer.DataChannelAvailable:
+        c.processDataChannelAvailableMessage(participant, msg)
+    case peer.RTCPReceived:
+        c.processForwardRTCPMessage(msg)
+    default:
+        c.logger.Errorf("Unknown message type: %T", msg)
+    }
+}
+
+func (c *Conference) processMatrixMessage(msg MatrixMessage) {
+    switch ev := msg.Content.(type) {
+    case *event.CallInviteEventContent:
+        c.onNewParticipant(msg.Sender, ev)
+    case *event.CallCandidatesEventContent:
+        c.onCandidates(msg.Sender, ev)
+    case *event.CallSelectAnswerEventContent:
+        c.onSelectAnswer(msg.Sender, ev)
+    case *event.CallHangupEventContent:
+        c.onHangup(msg.Sender, ev)
+    default:
+        c.logger.Errorf("Unexpected event type: %T", ev)
+    }
+}
diff --git a/pkg/conference/participant.go b/pkg/conference/participant.go
index 747a267..59c2c9d 100644
--- a/pkg/conference/participant.go
+++ b/pkg/conference/participant.go
@@ -2,6 +2,7 @@ package conference
 
 import (
     "encoding/json"
+    "time"
 
     "github.com/matrix-org/waterfall/pkg/peer"
     "github.com/matrix-org/waterfall/pkg/signaling"
@@ -19,6 +20,13 @@ type ParticipantID struct {
     CallID string
 }
 
+type PublishedTrack struct {
+    track *webrtc.TrackLocalStaticRTP
+    // The time when we sent the last PLI to the sender. We store this to avoid
+    // spamming the sender.
+    lastPLITimestamp time.Time
+}
+
 // Participant represents a participant in the conference.
 type Participant struct {
     id               ParticipantID
@@ -26,7 +34,7 @@
     peer             *peer.Peer[ParticipantID]
     remoteSessionID  id.SessionID
     streamMetadata   event.CallSDPStreamMetadata
-    publishedTracks  map[event.SFUTrackDescription]*webrtc.TrackLocalStaticRTP
+    publishedTracks  map[event.SFUTrackDescription]PublishedTrack
 }
 
 func (p *Participant) asMatrixRecipient() signaling.MatrixRecipient {
diff --git a/pkg/conference/peer_message_processor.go b/pkg/conference/peer_message_processor.go
new file mode 100644
index 0000000..a815972
--- /dev/null
+++ b/pkg/conference/peer_message_processor.go
@@ -0,0 +1,128 @@
+package conference
+
+import (
+    "encoding/json"
+    "time"
+
+    "github.com/matrix-org/waterfall/pkg/peer"
+    "github.com/pion/webrtc/v3"
+    "maunium.net/go/mautrix/event"
+)
+
+func (c *Conference) processJoinedTheCallMessage(participant *Participant, message peer.JoinedTheCall) {
+    participant.logger.Info("Joined the call")
+}
+
+func (c *Conference) processLeftTheCallMessage(participant *Participant, msg peer.LeftTheCall) {
+    participant.logger.Infof("Left the call: %s", msg.Reason)
+    c.removeParticipant(participant.id)
+    c.signaling.SendHangup(participant.asMatrixRecipient(), msg.Reason)
+}
+
+func (c *Conference) processNewTrackPublishedMessage(participant *Participant, msg peer.NewTrackPublished) {
+    participant.logger.Infof("Published new track: %s", msg.Track.ID())
+    key := event.SFUTrackDescription{
+        StreamID: msg.Track.StreamID(),
+        TrackID:  msg.Track.ID(),
+    }
+
+    if _, ok := participant.publishedTracks[key]; ok {
+        c.logger.Errorf("Track already published: %v", key)
+        return
+    }
+
+    participant.publishedTracks[key] = PublishedTrack{track: msg.Track}
+    c.resendMetadataToAllExcept(participant.id)
+}
+
+func (c *Conference) processPublishedTrackFailedMessage(participant *Participant, msg peer.PublishedTrackFailed) {
+    participant.logger.Infof("Failed published track: %s", msg.Track.ID())
+    delete(participant.publishedTracks, event.SFUTrackDescription{
+        StreamID: msg.Track.StreamID(),
+        TrackID:  msg.Track.ID(),
+    })
+
+    for _, otherParticipant := range c.participants {
+        if otherParticipant.id == participant.id {
+            continue
+        }
+
+        otherParticipant.peer.UnsubscribeFrom([]*webrtc.TrackLocalStaticRTP{msg.Track})
+    }
+
+    c.resendMetadataToAllExcept(participant.id)
+}
+
+func (c *Conference) processNewICECandidateMessage(participant *Participant, msg peer.NewICECandidate) {
+    participant.logger.Debug("Received a new local ICE candidate")
+
+    // Convert WebRTC ICE candidate to Matrix ICE candidate.
+    jsonCandidate := msg.Candidate.ToJSON()
+    candidates := []event.CallCandidate{{
+        Candidate:     jsonCandidate.Candidate,
+        SDPMLineIndex: int(*jsonCandidate.SDPMLineIndex),
+        SDPMID:        *jsonCandidate.SDPMid,
+    }}
+    c.signaling.SendICECandidates(participant.asMatrixRecipient(), candidates)
+}
+
+func (c *Conference) processICEGatheringCompleteMessage(participant *Participant, msg peer.ICEGatheringComplete) {
+    participant.logger.Info("Completed local ICE gathering")
+
+    // Send an empty array of candidates to indicate that ICE gathering is complete.
+    c.signaling.SendCandidatesGatheringFinished(participant.asMatrixRecipient())
+}
+
+func (c *Conference) processRenegotiationRequiredMessage(participant *Participant, msg peer.RenegotiationRequired) {
+    participant.logger.Info("Started renegotiation")
+    participant.sendDataChannelMessage(event.SFUMessage{
+        Op:       event.SFUOperationOffer,
+        SDP:      msg.Offer.SDP,
+        Metadata: c.getAvailableStreamsFor(participant.id),
+    })
+}
+
+func (c *Conference) processDataChannelMessage(participant *Participant, msg peer.DataChannelMessage) {
+    participant.logger.Debug("Received data channel message")
+    var sfuMessage event.SFUMessage
+    if err := json.Unmarshal([]byte(msg.Message), &sfuMessage); err != nil {
+        c.logger.Errorf("Failed to unmarshal SFU message: %v", err)
+        return
+    }
+
+    switch sfuMessage.Op {
+    case event.SFUOperationSelect:
+        c.processSelectDCMessage(participant, sfuMessage)
+    case event.SFUOperationAnswer:
+        c.processAnswerDCMessage(participant, sfuMessage)
+    case event.SFUOperationPublish:
+        c.processPublishDCMessage(participant, sfuMessage)
+    case event.SFUOperationUnpublish:
+        c.processUnpublishDCMessage(participant)
+    case event.SFUOperationAlive:
+        c.processAliveDCMessage(participant)
+    case event.SFUOperationMetadata:
+        c.processMetadataDCMessage(participant, sfuMessage)
+    }
+}
+
+func (c *Conference) processDataChannelAvailableMessage(participant *Participant, msg peer.DataChannelAvailable) {
+    participant.logger.Info("Connected data channel")
+    participant.sendDataChannelMessage(event.SFUMessage{
+        Op:       event.SFUOperationMetadata,
+        Metadata: c.getAvailableStreamsFor(participant.id),
+    })
+}
+
+func (c *Conference) processForwardRTCPMessage(msg peer.RTCPReceived) {
+    for _, participant := range c.participants {
+        for key, publishedTrack := range participant.publishedTracks {
+            if publishedTrack.track.StreamID() == msg.StreamID && publishedTrack.track.ID() == msg.TrackID {
+                err := participant.peer.WriteRTCP(msg.Packets, msg.StreamID, msg.TrackID, publishedTrack.lastPLITimestamp)
+                if err == nil {
+                    participant.publishedTracks[key] = PublishedTrack{track: publishedTrack.track, lastPLITimestamp: time.Now()}
+                }
+            }
+        }
+    }
+}
diff --git a/pkg/conference/processor.go b/pkg/conference/processor.go
deleted file mode 100644
index 8a5b681..0000000
--- a/pkg/conference/processor.go
+++ /dev/null
@@ -1,220 +0,0 @@
-package conference
-
-import (
-    "encoding/json"
-    "errors"
-
-    "github.com/matrix-org/waterfall/pkg/common"
-    "github.com/matrix-org/waterfall/pkg/peer"
-    "github.com/pion/webrtc/v3"
-    "golang.org/x/exp/slices"
-    "maunium.net/go/mautrix/event"
-)
-
-// Listen on messages from incoming channels and process them.
-// This is essentially the main loop of the conference.
-// If this function returns, the conference is over.
-func (c *Conference) processMessages() {
-    for {
-        select {
-        case msg := <-c.peerMessages:
-            c.processPeerMessage(msg)
-        case msg := <-c.matrixMessages.Channel:
-            c.processMatrixMessage(msg)
-        }
-
-        // If there are no more participants, stop the conference.
-        if len(c.participants) == 0 {
-            c.logger.Info("No more participants, stopping the conference")
-            // Close the channel so that the sender can't push any messages.
-            unreadMessages := c.matrixMessages.Close()
-
-            // Send the information that we ended to the owner and pass the message
-            // that we did not process (so that we don't drop it silently).
-            c.endNotifier.Notify(unreadMessages)
-            return
-        }
-    }
-}
-
-// Process a message from a local peer.
-func (c *Conference) processPeerMessage(message common.Message[ParticipantID, peer.MessageContent]) {
-    participant := c.getParticipant(message.Sender, errors.New("received a message from a deleted participant"))
-    if participant == nil {
-        return
-    }
-
-    // Since Go does not support ADTs, we have to use a switch statement to
-    // determine the actual type of the message.
-    switch msg := message.Content.(type) {
-    case peer.JoinedTheCall:
-        participant.logger.Info("Joined the call")
-
-    case peer.LeftTheCall:
-        participant.logger.Info("Left the call: %s", msg.Reason)
-        c.removeParticipant(message.Sender)
-        c.signaling.SendHangup(participant.asMatrixRecipient(), msg.Reason)
-
-    case peer.NewTrackPublished:
-        participant.logger.Infof("Published new track: %s", msg.Track.ID())
-        key := event.SFUTrackDescription{
-            StreamID: msg.Track.StreamID(),
-            TrackID:  msg.Track.ID(),
-        }
-
-        if _, ok := participant.publishedTracks[key]; ok {
-            c.logger.Errorf("Track already published: %v", key)
-            return
-        }
-
-        participant.publishedTracks[key] = msg.Track
-        c.resendMetadataToAllExcept(participant.id)
-
-    case peer.PublishedTrackFailed:
-        participant.logger.Infof("Failed published track: %s", msg.Track.ID())
-        delete(participant.publishedTracks, event.SFUTrackDescription{
-            StreamID: msg.Track.StreamID(),
-            TrackID:  msg.Track.ID(),
-        })
-
-        for _, otherParticipant := range c.participants {
-            if otherParticipant.id == participant.id {
-                continue
-            }
-
-            otherParticipant.peer.UnsubscribeFrom([]*webrtc.TrackLocalStaticRTP{msg.Track})
-        }
-
-        c.resendMetadataToAllExcept(participant.id)
-
-    case peer.NewICECandidate:
-        participant.logger.Debug("Received a new local ICE candidate")
-
-        // Convert WebRTC ICE candidate to Matrix ICE candidate.
-        jsonCandidate := msg.Candidate.ToJSON()
-        candidates := []event.CallCandidate{{
-            Candidate:     jsonCandidate.Candidate,
-            SDPMLineIndex: int(*jsonCandidate.SDPMLineIndex),
-            SDPMID:        *jsonCandidate.SDPMid,
-        }}
-        c.signaling.SendICECandidates(participant.asMatrixRecipient(), candidates)
-
-    case peer.ICEGatheringComplete:
-        participant.logger.Info("Completed local ICE gathering")
-
-        // Send an empty array of candidates to indicate that ICE gathering is complete.
-        c.signaling.SendCandidatesGatheringFinished(participant.asMatrixRecipient())
-
-    case peer.RenegotiationRequired:
-        participant.logger.Info("Started renegotiation")
-        participant.sendDataChannelMessage(event.SFUMessage{
-            Op:       event.SFUOperationOffer,
-            SDP:      msg.Offer.SDP,
-            Metadata: c.getAvailableStreamsFor(participant.id),
-        })
-
-    case peer.DataChannelMessage:
-        participant.logger.Debug("Received data channel message")
-        var sfuMessage event.SFUMessage
-        if err := json.Unmarshal([]byte(msg.Message), &sfuMessage); err != nil {
-            c.logger.Errorf("Failed to unmarshal SFU message: %v", err)
-            return
-        }
-
-        c.handleDataChannelMessage(participant, sfuMessage)
-
-    case peer.DataChannelAvailable:
-        participant.logger.Info("Connected data channel")
-        participant.sendDataChannelMessage(event.SFUMessage{
-            Op:       event.SFUOperationMetadata,
-            Metadata: c.getAvailableStreamsFor(participant.id),
-        })
-
-    default:
-        c.logger.Errorf("Unknown message type: %T", msg)
-    }
-}
-
-// Handle the `SFUMessage` event from the DataChannel message.
-func (c *Conference) handleDataChannelMessage(participant *Participant, sfuMessage event.SFUMessage) {
-    switch sfuMessage.Op {
-    case event.SFUOperationSelect:
-        participant.logger.Info("Received select request over DC")
-
-        // Find tracks based on what we were asked for.
-        tracks := c.getTracks(sfuMessage.Start)
-
-        // Let's check if we have all the tracks that we were asked for are there.
-        // If not, we will list which are not available (later on we must inform participant
-        // about it unless the participant retries it).
-        if len(tracks) != len(sfuMessage.Start) {
-            for _, expected := range sfuMessage.Start {
-                found := slices.IndexFunc(tracks, func(track *webrtc.TrackLocalStaticRTP) bool {
-                    return track.StreamID() == expected.StreamID && track.ID() == expected.TrackID
-                })
-
-                if found == -1 {
-                    c.logger.Warnf("Track not found: %s", expected.TrackID)
-                }
-            }
-        }
-
-        // Subscribe to the found tracks.
-        for _, track := range tracks {
-            if err := participant.peer.SubscribeTo(track); err != nil {
-                participant.logger.Errorf("Failed to subscribe to track: %v", err)
-                return
-            }
-        }
-
-    case event.SFUOperationAnswer:
-        participant.logger.Info("Received SDP answer over DC")
-
-        if err := participant.peer.ProcessSDPAnswer(sfuMessage.SDP); err != nil {
-            participant.logger.Errorf("Failed to set SDP answer: %v", err)
-            return
-        }
-
-    case event.SFUOperationPublish:
-        participant.logger.Info("Received SDP offer over DC")
-
-        answer, err := participant.peer.ProcessSDPOffer(sfuMessage.SDP)
-        if err != nil {
-            participant.logger.Errorf("Failed to set SDP offer: %v", err)
-            return
-        }
-
-        participant.streamMetadata = sfuMessage.Metadata
-
-        participant.sendDataChannelMessage(event.SFUMessage{
-            Op:       event.SFUOperationAnswer,
-            SDP:      answer.SDP,
-            Metadata: c.getAvailableStreamsFor(participant.id),
-        })
-
-    case event.SFUOperationUnpublish:
-        participant.logger.Info("Received unpublish over DC")
-
-    case event.SFUOperationAlive:
-        participant.peer.ProcessHeartbeat()
-
-    case event.SFUOperationMetadata:
-        participant.streamMetadata = sfuMessage.Metadata
-        c.resendMetadataToAllExcept(participant.id)
-    }
-}
-
-func (c *Conference) processMatrixMessage(msg MatrixMessage) {
-    switch ev := msg.Content.(type) {
-    case *event.CallInviteEventContent:
-        c.onNewParticipant(msg.Sender, ev)
-    case *event.CallCandidatesEventContent:
-        c.onCandidates(msg.Sender, ev)
-    case *event.CallSelectAnswerEventContent:
-        c.onSelectAnswer(msg.Sender, ev)
-    case *event.CallHangupEventContent:
-        c.onHangup(msg.Sender, ev)
-    default:
-        c.logger.Errorf("Unexpected event type: %T", ev)
-    }
-}
diff --git a/pkg/conference/state.go b/pkg/conference/state.go
index 5acf97d..de43c00 100644
--- a/pkg/conference/state.go
+++ b/pkg/conference/state.go
@@ -6,7 +6,6 @@ import (
     "github.com/matrix-org/waterfall/pkg/signaling"
     "github.com/pion/webrtc/v3"
     "github.com/sirupsen/logrus"
-    "golang.org/x/exp/maps"
     "maunium.net/go/mautrix/event"
 )
 
@@ -56,7 +55,10 @@ func (c *Conference) removeParticipant(participantID ParticipantID) {
     c.resendMetadataToAllExcept(participantID)
 
     // Remove the participant's tracks from all participants who might have subscribed to them.
-    obsoleteTracks := maps.Values(participant.publishedTracks)
+    obsoleteTracks := []*webrtc.TrackLocalStaticRTP{}
+    for _, publishedTrack := range participant.publishedTracks {
+        obsoleteTracks = append(obsoleteTracks, publishedTrack.track)
+    }
     for _, otherParticipant := range c.participants {
         otherParticipant.peer.UnsubscribeFrom(obsoleteTracks)
     }
@@ -72,7 +74,7 @@ func (c *Conference) getAvailableStreamsFor(forParticipant ParticipantID) event.
     // Now, find out which of published tracks belong to the streams for which we have metadata
     // available and construct a metadata map for a given participant based on that.
     for _, track := range participant.publishedTracks {
-        trackID, streamID := track.ID(), track.StreamID()
+        trackID, streamID := track.track.ID(), track.track.StreamID()
 
         if metadata, ok := streamsMetadata[streamID]; ok {
             metadata.Tracks[trackID] = event.CallSDPStreamMetadataTrack{}
@@ -97,7 +99,7 @@ func (c *Conference) getTracks(identifiers []event.SFUTrackDescription) []*webrt
         // Check if this participant has any of the tracks that we're looking for.
         for _, identifier := range identifiers {
             if track, ok := participant.publishedTracks[identifier]; ok {
-                tracks = append(tracks, track)
+                tracks = append(tracks, track.track)
             }
         }
     }
diff --git a/pkg/peer/messages.go b/pkg/peer/messages.go
index 6a05a82..0593d72 100644
--- a/pkg/peer/messages.go
+++ b/pkg/peer/messages.go
@@ -1,6 +1,7 @@
 package peer
 
 import (
+    "github.com/pion/rtcp"
     "github.com/pion/webrtc/v3"
     "maunium.net/go/mautrix/event"
 )
@@ -38,3 +39,9 @@ type DataChannelMessage struct {
 }
 
 type DataChannelAvailable struct{}
+
+type RTCPReceived struct {
+    Packets  []rtcp.Packet
+    StreamID string
+    TrackID  string
+}
diff --git a/pkg/peer/peer.go b/pkg/peer/peer.go
index ee43324..a226633 100644
--- a/pkg/peer/peer.go
+++ b/pkg/peer/peer.go
@@ -2,24 +2,28 @@ package peer
 
 import (
     "errors"
+    "io"
     "sync"
     "time"
 
     "github.com/matrix-org/waterfall/pkg/common"
+    "github.com/pion/rtcp"
     "github.com/pion/webrtc/v3"
     "github.com/sirupsen/logrus"
+    "golang.org/x/exp/slices"
     "maunium.net/go/mautrix/event"
 )
 
 var (
     ErrCantCreatePeerConnection   = errors.New("can't create peer connection")
-    ErrCantSetRemoteDecsription   = errors.New("can't set remote description")
+    ErrCantSetRemoteDescription   = errors.New("can't set remote description")
     ErrCantCreateAnswer           = errors.New("can't create answer")
     ErrCantSetLocalDescription    = errors.New("can't set local description")
     ErrCantCreateLocalDescription = errors.New("can't create local description")
     ErrDataChannelNotAvailable    = errors.New("data channel is not available")
     ErrDataChannelNotReady        = errors.New("data channel is not ready")
     ErrCantSubscribeToTrack       = errors.New("can't subscribe to track")
+    ErrCantWriteRTCP              = errors.New("can't write RTCP")
 )
 
 // A wrapped representation of the peer connection (single peer in the call).
@@ -98,17 +102,77 @@ func (p *Peer[ID]) SubscribeTo(track *webrtc.TrackLocalStaticRTP) error {
     // Before these packets are returned they are processed by interceptors. For things
     // like NACK this needs to be called.
     go func() {
-        rtcpBuf := make([]byte, 1500)
         for {
-            if _, _, rtcpErr := rtpSender.Read(rtcpBuf); rtcpErr != nil {
-                return
+            packets, _, err := rtpSender.ReadRTCP()
+            if err != nil {
+                if errors.Is(err, io.ErrClosedPipe) || errors.Is(err, io.EOF) {
+                    return
+                }
+
+                p.logger.WithError(err).Warn("failed to read RTCP on track")
             }
+
+            p.sink.Send(RTCPReceived{Packets: packets, TrackID: track.ID(), StreamID: track.StreamID()})
         }
     }()
 
     return nil
 }
 
+func (p *Peer[ID]) WriteRTCP(packets []rtcp.Packet, streamID string, trackID string, lastPLITimestamp time.Time) error {
+    const minimalPLIInterval = time.Millisecond * 500
+
+    packetsToSend := []rtcp.Packet{}
+    var mediaSSRC uint32
+    receivers := p.peerConnection.GetReceivers()
+    receiverIndex := slices.IndexFunc(receivers, func(receiver *webrtc.RTPReceiver) bool {
+        return receiver.Track().ID() == trackID && receiver.Track().StreamID() == streamID
+    })
+
+    if receiverIndex == -1 {
+        p.logger.Error("failed to find track to write RTCP on")
+        return ErrCantWriteRTCP
+    } else {
+        mediaSSRC = uint32(receivers[receiverIndex].Track().SSRC())
+    }
+
+    for _, packet := range packets {
+        switch typedPacket := packet.(type) {
+        // We mung the packets here, so that the SSRCs match what the
+        // receiver expects:
+        // The media SSRC is the SSRC of the media about which the packet is
+        // reporting; therefore, we mung it to be the SSRC of the publishing
+        // participant's track. Without this, it would be the SSRC of the SFU's
+        // track, which isn't right.
+        case *rtcp.PictureLossIndication:
+            // Since we sometimes spam the sender with PLIs, make sure we don't send
+            // them way too often.
+            if time.Now().UnixNano()-lastPLITimestamp.UnixNano() < minimalPLIInterval.Nanoseconds() {
+                continue
+            }
+
+            typedPacket.MediaSSRC = mediaSSRC
+            packetsToSend = append(packetsToSend, typedPacket)
+        case *rtcp.FullIntraRequest:
+            typedPacket.MediaSSRC = mediaSSRC
+            packetsToSend = append(packetsToSend, typedPacket)
+        default:
+            packetsToSend = append(packetsToSend, packet)
+        }
+    }
+
+    if len(packetsToSend) != 0 {
+        if err := p.peerConnection.WriteRTCP(packetsToSend); err != nil {
+            if !errors.Is(err, io.ErrClosedPipe) {
+                p.logger.WithError(err).Error("failed to write RTCP on track")
+                return err
+            }
+        }
+    }
+
+    return nil
+}
+
 // Unsubscribes from the given list of tracks.
 func (p *Peer[ID]) UnsubscribeFrom(tracks []*webrtc.TrackLocalStaticRTP) {
     // That's unfortunately an O(m*n) operation, but we don't expect the number of tracks to be big.
@@ -169,7 +233,7 @@ func (p *Peer[ID]) ProcessSDPAnswer(sdpAnswer string) error {
     })
     if err != nil {
         p.logger.WithError(err).Error("failed to set remote description")
-        return ErrCantSetRemoteDecsription
+        return ErrCantSetRemoteDescription
     }
 
     return nil
@@ -183,7 +247,7 @@ func (p *Peer[ID]) ProcessSDPOffer(sdpOffer string) (*webrtc.SessionDescription,
     })
     if err != nil {
         p.logger.WithError(err).Error("failed to set remote description")
-        return nil, ErrCantSetRemoteDecsription
+        return nil, ErrCantSetRemoteDescription
     }
 
     answer, err := p.peerConnection.CreateAnswer(nil)
diff --git a/pkg/peer/webrtc.go b/pkg/peer/webrtc.go
index 2a97155..8806689 100644
--- a/pkg/peer/webrtc.go
+++ b/pkg/peer/webrtc.go
@@ -3,9 +3,7 @@ package peer
 import (
     "errors"
     "io"
-    "time"
 
-    "github.com/pion/rtcp"
     "github.com/pion/webrtc/v3"
     "maunium.net/go/mautrix/event"
 )
@@ -13,22 +11,6 @@ import (
 // A callback that is called once we receive first RTP packets from a track, i.e.
 // we call this function each time a new track is received.
 func (p *Peer[ID]) onRtpTrackReceived(remoteTrack *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) {
-    // Send a PLI on an interval so that the publisher is pushing a keyframe every rtcpPLIInterval.
-    // This can be less wasteful by processing incoming RTCP events, then we would emit a NACK/PLI
-    // when a viewer requests it.
-    //
-    // TODO: Add RTCP handling based on the PR from @SimonBrandner.
-    go func() {
-        ticker := time.NewTicker(time.Millisecond * 500) // every 500ms
-        for range ticker.C {
-            rtcp := []rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: uint32(remoteTrack.SSRC())}}
-            if err := p.peerConnection.WriteRTCP(rtcp); err != nil && !errors.Is(err, io.ErrClosedPipe) {
-                p.logger.Errorf("Failed to send RTCP PLI: %v", err)
-                return
-            }
-        }
-    }()
-
     // Create a local track, all our SFU clients that are subscribed to this
     // peer (publisher) wil be fed via this track.
     localTrack, err := webrtc.NewTrackLocalStaticRTP(