Skip to content

Commit

Permalink
Merge pull request #54 from matrix-org/SimonBrandner/feat/rtcp-forward
Browse files Browse the repository at this point in the history
Implement RTCP forwarding in the refactored version of the SFU
  • Loading branch information
daniel-abramov authored Dec 5, 2022
2 parents 78e9e99 + 5f90f21 commit 1b6743e
Show file tree
Hide file tree
Showing 10 changed files with 384 additions and 250 deletions.
78 changes: 78 additions & 0 deletions pkg/conference/data_channel_message_processor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package conference

import (
"github.com/pion/webrtc/v3"
"golang.org/x/exp/slices"
"maunium.net/go/mautrix/event"
)

// Handle the `SFUMessage` event from the DataChannel message.
func (c *Conference) processSelectDCMessage(participant *Participant, msg event.SFUMessage) {
participant.logger.Info("Received select request over DC")

// Find tracks based on what we were asked for.
tracks := c.getTracks(msg.Start)

// Let's check if we have all the tracks that we were asked for are there.
// If not, we will list which are not available (later on we must inform participant
// about it unless the participant retries it).
if len(tracks) != len(msg.Start) {
for _, expected := range msg.Start {
found := slices.IndexFunc(tracks, func(track *webrtc.TrackLocalStaticRTP) bool {
return track.StreamID() == expected.StreamID && track.ID() == expected.TrackID
})

if found == -1 {
c.logger.Warnf("Track not found: %s", expected.TrackID)
}
}
}

// Subscribe to the found tracks.
for _, track := range tracks {
if err := participant.peer.SubscribeTo(track); err != nil {
participant.logger.Errorf("Failed to subscribe to track: %v", err)
return
}
}
}

func (c *Conference) processAnswerDCMessage(participant *Participant, msg event.SFUMessage) {
participant.logger.Info("Received SDP answer over DC")

if err := participant.peer.ProcessSDPAnswer(msg.SDP); err != nil {
participant.logger.Errorf("Failed to set SDP answer: %v", err)
return
}
}

func (c *Conference) processPublishDCMessage(participant *Participant, msg event.SFUMessage) {
participant.logger.Info("Received SDP offer over DC")

answer, err := participant.peer.ProcessSDPOffer(msg.SDP)
if err != nil {
participant.logger.Errorf("Failed to set SDP offer: %v", err)
return
}

participant.streamMetadata = msg.Metadata

participant.sendDataChannelMessage(event.SFUMessage{
Op: event.SFUOperationAnswer,
SDP: answer.SDP,
Metadata: c.getAvailableStreamsFor(participant.id),
})
}

func (c *Conference) processUnpublishDCMessage(participant *Participant) {
participant.logger.Info("Received unpublish over DC")
}

func (c *Conference) processAliveDCMessage(participant *Participant) {
participant.peer.ProcessHeartbeat()
}

func (c *Conference) processMetadataDCMessage(participant *Participant, msg event.SFUMessage) {
participant.streamMetadata = msg.Metadata
c.resendMetadataToAllExcept(participant.id)
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func (c *Conference) onNewParticipant(participantID ParticipantID, inviteEvent *
logger: logger,
remoteSessionID: inviteEvent.SenderSessionID,
streamMetadata: inviteEvent.SDPStreamMetadata,
publishedTracks: make(map[event.SFUTrackDescription]*webrtc.TrackLocalStaticRTP),
publishedTracks: make(map[event.SFUTrackDescription]PublishedTrack),
}

c.participants[participantID] = participant
Expand Down
85 changes: 85 additions & 0 deletions pkg/conference/messsage_processor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package conference

import (
"errors"

"github.com/matrix-org/waterfall/pkg/common"
"github.com/matrix-org/waterfall/pkg/peer"
"maunium.net/go/mautrix/event"
)

// Listen on messages from incoming channels and process them.
// This is essentially the main loop of the conference.
// If this function returns, the conference is over.
func (c *Conference) processMessages() {
for {
select {
case msg := <-c.peerMessages:
c.processPeerMessage(msg)
case msg := <-c.matrixMessages.Channel:
c.processMatrixMessage(msg)
}

// If there are no more participants, stop the conference.
if len(c.participants) == 0 {
c.logger.Info("No more participants, stopping the conference")
// Close the channel so that the sender can't push any messages.
unreadMessages := c.matrixMessages.Close()

// Send the information that we ended to the owner and pass the message
// that we did not process (so that we don't drop it silently).
c.endNotifier.Notify(unreadMessages)
return
}
}
}

// Process a message from a local peer.
func (c *Conference) processPeerMessage(message common.Message[ParticipantID, peer.MessageContent]) {
participant := c.getParticipant(message.Sender, errors.New("received a message from a deleted participant"))
if participant == nil {
return
}

// Since Go does not support ADTs, we have to use a switch statement to
// determine the actual type of the message.
switch msg := message.Content.(type) {
case peer.JoinedTheCall:
c.processJoinedTheCallMessage(participant, msg)
case peer.LeftTheCall:
c.processLeftTheCallMessage(participant, msg)
case peer.NewTrackPublished:
c.processNewTrackPublishedMessage(participant, msg)
case peer.PublishedTrackFailed:
c.processPublishedTrackFailedMessage(participant, msg)
case peer.NewICECandidate:
c.processNewICECandidateMessage(participant, msg)
case peer.ICEGatheringComplete:
c.processICEGatheringCompleteMessage(participant, msg)
case peer.RenegotiationRequired:
c.processRenegotiationRequiredMessage(participant, msg)
case peer.DataChannelMessage:
c.processDataChannelMessage(participant, msg)
case peer.DataChannelAvailable:
c.processDataChannelAvailableMessage(participant, msg)
case peer.RTCPReceived:
c.processForwardRTCPMessage(msg)
default:
c.logger.Errorf("Unknown message type: %T", msg)
}
}

func (c *Conference) processMatrixMessage(msg MatrixMessage) {
switch ev := msg.Content.(type) {
case *event.CallInviteEventContent:
c.onNewParticipant(msg.Sender, ev)
case *event.CallCandidatesEventContent:
c.onCandidates(msg.Sender, ev)
case *event.CallSelectAnswerEventContent:
c.onSelectAnswer(msg.Sender, ev)
case *event.CallHangupEventContent:
c.onHangup(msg.Sender, ev)
default:
c.logger.Errorf("Unexpected event type: %T", ev)
}
}
10 changes: 9 additions & 1 deletion pkg/conference/participant.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package conference

import (
"encoding/json"
"time"

"github.com/matrix-org/waterfall/pkg/peer"
"github.com/matrix-org/waterfall/pkg/signaling"
Expand All @@ -19,14 +20,21 @@ type ParticipantID struct {
CallID string
}

type PublishedTrack struct {
track *webrtc.TrackLocalStaticRTP
// The time when we sent the last PLI to the sender. We store this to avoid
// spamming the sender.
lastPLITimestamp time.Time
}

// Participant represents a participant in the conference.
type Participant struct {
id ParticipantID
logger *logrus.Entry
peer *peer.Peer[ParticipantID]
remoteSessionID id.SessionID
streamMetadata event.CallSDPStreamMetadata
publishedTracks map[event.SFUTrackDescription]*webrtc.TrackLocalStaticRTP
publishedTracks map[event.SFUTrackDescription]PublishedTrack
}

func (p *Participant) asMatrixRecipient() signaling.MatrixRecipient {
Expand Down
128 changes: 128 additions & 0 deletions pkg/conference/peer_message_processor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
package conference

import (
"encoding/json"
"time"

"github.com/matrix-org/waterfall/pkg/peer"
"github.com/pion/webrtc/v3"
"maunium.net/go/mautrix/event"
)

func (c *Conference) processJoinedTheCallMessage(participant *Participant, message peer.JoinedTheCall) {
participant.logger.Info("Joined the call")
}

func (c *Conference) processLeftTheCallMessage(participant *Participant, msg peer.LeftTheCall) {
participant.logger.Info("Left the call: %s", msg.Reason)
c.removeParticipant(participant.id)
c.signaling.SendHangup(participant.asMatrixRecipient(), msg.Reason)
}

func (c *Conference) processNewTrackPublishedMessage(participant *Participant, msg peer.NewTrackPublished) {
participant.logger.Infof("Published new track: %s", msg.Track.ID())
key := event.SFUTrackDescription{
StreamID: msg.Track.StreamID(),
TrackID: msg.Track.ID(),
}

if _, ok := participant.publishedTracks[key]; ok {
c.logger.Errorf("Track already published: %v", key)
return
}

participant.publishedTracks[key] = PublishedTrack{track: msg.Track}
c.resendMetadataToAllExcept(participant.id)
}

func (c *Conference) processPublishedTrackFailedMessage(participant *Participant, msg peer.PublishedTrackFailed) {
participant.logger.Infof("Failed published track: %s", msg.Track.ID())
delete(participant.publishedTracks, event.SFUTrackDescription{
StreamID: msg.Track.StreamID(),
TrackID: msg.Track.ID(),
})

for _, otherParticipant := range c.participants {
if otherParticipant.id == participant.id {
continue
}

otherParticipant.peer.UnsubscribeFrom([]*webrtc.TrackLocalStaticRTP{msg.Track})
}

c.resendMetadataToAllExcept(participant.id)
}

func (c *Conference) processNewICECandidateMessage(participant *Participant, msg peer.NewICECandidate) {
participant.logger.Debug("Received a new local ICE candidate")

// Convert WebRTC ICE candidate to Matrix ICE candidate.
jsonCandidate := msg.Candidate.ToJSON()
candidates := []event.CallCandidate{{
Candidate: jsonCandidate.Candidate,
SDPMLineIndex: int(*jsonCandidate.SDPMLineIndex),
SDPMID: *jsonCandidate.SDPMid,
}}
c.signaling.SendICECandidates(participant.asMatrixRecipient(), candidates)
}

func (c *Conference) processICEGatheringCompleteMessage(participant *Participant, msg peer.ICEGatheringComplete) {
participant.logger.Info("Completed local ICE gathering")

// Send an empty array of candidates to indicate that ICE gathering is complete.
c.signaling.SendCandidatesGatheringFinished(participant.asMatrixRecipient())
}

func (c *Conference) processRenegotiationRequiredMessage(participant *Participant, msg peer.RenegotiationRequired) {
participant.logger.Info("Started renegotiation")
participant.sendDataChannelMessage(event.SFUMessage{
Op: event.SFUOperationOffer,
SDP: msg.Offer.SDP,
Metadata: c.getAvailableStreamsFor(participant.id),
})
}

func (c *Conference) processDataChannelMessage(participant *Participant, msg peer.DataChannelMessage) {
participant.logger.Debug("Received data channel message")
var sfuMessage event.SFUMessage
if err := json.Unmarshal([]byte(msg.Message), &sfuMessage); err != nil {
c.logger.Errorf("Failed to unmarshal SFU message: %v", err)
return
}

switch sfuMessage.Op {
case event.SFUOperationSelect:
c.processSelectDCMessage(participant, sfuMessage)
case event.SFUOperationAnswer:
c.processAnswerDCMessage(participant, sfuMessage)
case event.SFUOperationPublish:
c.processPublishDCMessage(participant, sfuMessage)
case event.SFUOperationUnpublish:
c.processUnpublishDCMessage(participant)
case event.SFUOperationAlive:
c.processAliveDCMessage(participant)
case event.SFUOperationMetadata:
c.processMetadataDCMessage(participant, sfuMessage)
}
}

func (c *Conference) processDataChannelAvailableMessage(participant *Participant, msg peer.DataChannelAvailable) {
participant.logger.Info("Connected data channel")
participant.sendDataChannelMessage(event.SFUMessage{
Op: event.SFUOperationMetadata,
Metadata: c.getAvailableStreamsFor(participant.id),
})
}

func (c *Conference) processForwardRTCPMessage(msg peer.RTCPReceived) {
for _, participant := range c.participants {
for _, publishedTrack := range participant.publishedTracks {
if publishedTrack.track.StreamID() == msg.StreamID && publishedTrack.track.ID() == msg.TrackID {
err := participant.peer.WriteRTCP(msg.Packets, msg.StreamID, msg.TrackID, publishedTrack.lastPLITimestamp)
if err == nil {
publishedTrack.lastPLITimestamp = time.Now()
}
}
}
}
}
Loading

0 comments on commit 1b6743e

Please sign in to comment.