From ec5841b3753b42d01c5bed2640640140161123df Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Mon, 5 Dec 2022 19:17:41 +0100 Subject: [PATCH] conference: use TrackID as identifier for tracks Theoretically we don't need to use a combination of TrackID and StreamID to uniquely identify tracks inside as long as the GUIDs are used for the tracks. Closes https://github.com/matrix-org/waterfall/issues/56. --- .../data_channel_message_processor.go | 2 +- pkg/conference/matrix_message_processor.go | 2 +- pkg/conference/participant.go | 2 +- pkg/conference/peer_message_processor.go | 21 +++++++------------ pkg/conference/state.go | 3 ++- pkg/peer/messages.go | 5 ++--- pkg/peer/peer.go | 6 +++--- 7 files changed, 17 insertions(+), 24 deletions(-) diff --git a/pkg/conference/data_channel_message_processor.go b/pkg/conference/data_channel_message_processor.go index 4edf49e..e73cc31 100644 --- a/pkg/conference/data_channel_message_processor.go +++ b/pkg/conference/data_channel_message_processor.go @@ -19,7 +19,7 @@ func (c *Conference) processSelectDCMessage(participant *Participant, msg event. if len(tracks) != len(msg.Start) { for _, expected := range msg.Start { found := slices.IndexFunc(tracks, func(track *webrtc.TrackLocalStaticRTP) bool { - return track.StreamID() == expected.StreamID && track.ID() == expected.TrackID + return track.ID() == expected.TrackID }) if found == -1 { diff --git a/pkg/conference/matrix_message_processor.go b/pkg/conference/matrix_message_processor.go index 691a1f5..20ae788 100644 --- a/pkg/conference/matrix_message_processor.go +++ b/pkg/conference/matrix_message_processor.go @@ -65,7 +65,7 @@ func (c *Conference) onNewParticipant(participantID ParticipantID, inviteEvent * logger: logger, remoteSessionID: inviteEvent.SenderSessionID, streamMetadata: inviteEvent.SDPStreamMetadata, - publishedTracks: make(map[event.SFUTrackDescription]PublishedTrack), + publishedTracks: make(map[string]PublishedTrack), } c.participants[participantID] = participant diff --git a/pkg/conference/participant.go b/pkg/conference/participant.go index 757d13d..0bb92f3 100644 --- a/pkg/conference/participant.go +++ b/pkg/conference/participant.go @@ -34,7 +34,7 @@ type Participant struct { peer *peer.Peer[ParticipantID] remoteSessionID id.SessionID streamMetadata event.CallSDPStreamMetadata - publishedTracks map[event.SFUTrackDescription]PublishedTrack + publishedTracks map[string]PublishedTrack } func (p *Participant) asMatrixRecipient() signaling.MatrixRecipient { diff --git a/pkg/conference/peer_message_processor.go b/pkg/conference/peer_message_processor.go index a815972..18a15d3 100644 --- a/pkg/conference/peer_message_processor.go +++ b/pkg/conference/peer_message_processor.go @@ -21,26 +21,19 @@ func (c *Conference) processLeftTheCallMessage(participant *Participant, msg pee func (c *Conference) processNewTrackPublishedMessage(participant *Participant, msg peer.NewTrackPublished) { participant.logger.Infof("Published new track: %s", msg.Track.ID()) - key := event.SFUTrackDescription{ - StreamID: msg.Track.StreamID(), - TrackID: msg.Track.ID(), - } - if _, ok := participant.publishedTracks[key]; ok { - c.logger.Errorf("Track already published: %v", key) + if _, ok := participant.publishedTracks[msg.Track.ID()]; ok { + c.logger.Errorf("Track already published: %v", msg.Track.ID()) return } - participant.publishedTracks[key] = PublishedTrack{track: msg.Track} + participant.publishedTracks[msg.Track.ID()] = PublishedTrack{track: msg.Track} c.resendMetadataToAllExcept(participant.id) } func (c *Conference) processPublishedTrackFailedMessage(participant *Participant, msg peer.PublishedTrackFailed) { participant.logger.Infof("Failed published track: %s", msg.Track.ID()) - delete(participant.publishedTracks, event.SFUTrackDescription{ - StreamID: msg.Track.StreamID(), - TrackID: msg.Track.ID(), - }) + delete(participant.publishedTracks, msg.Track.ID()) for _, otherParticipant := range c.participants { if otherParticipant.id == participant.id { @@ -116,9 +109,9 @@ func (c *Conference) processDataChannelAvailableMessage(participant *Participant func (c *Conference) processForwardRTCPMessage(msg peer.RTCPReceived) { for _, participant := range c.participants { - for _, publishedTrack := range participant.publishedTracks { - if publishedTrack.track.StreamID() == msg.StreamID && publishedTrack.track.ID() == msg.TrackID { - err := participant.peer.WriteRTCP(msg.Packets, msg.StreamID, msg.TrackID, publishedTrack.lastPLITimestamp) + for id, publishedTrack := range participant.publishedTracks { + if id == msg.TrackID { + err := participant.peer.WriteRTCP(msg.Packets, msg.TrackID, publishedTrack.lastPLITimestamp) if err == nil { publishedTrack.lastPLITimestamp = time.Now() } diff --git a/pkg/conference/state.go b/pkg/conference/state.go index de43c00..838479f 100644 --- a/pkg/conference/state.go +++ b/pkg/conference/state.go @@ -59,6 +59,7 @@ func (c *Conference) removeParticipant(participantID ParticipantID) { for _, publishedTrack := range participant.publishedTracks { obsoleteTracks = append(obsoleteTracks, publishedTrack.track) } + for _, otherParticipant := range c.participants { otherParticipant.peer.UnsubscribeFrom(obsoleteTracks) } @@ -98,7 +99,7 @@ func (c *Conference) getTracks(identifiers []event.SFUTrackDescription) []*webrt for _, participant := range c.participants { // Check if this participant has any of the tracks that we're looking for. for _, identifier := range identifiers { - if track, ok := participant.publishedTracks[identifier]; ok { + if track, ok := participant.publishedTracks[identifier.TrackID]; ok { tracks = append(tracks, track.track) } } diff --git a/pkg/peer/messages.go b/pkg/peer/messages.go index 0593d72..2712518 100644 --- a/pkg/peer/messages.go +++ b/pkg/peer/messages.go @@ -41,7 +41,6 @@ type DataChannelMessage struct { type DataChannelAvailable struct{} type RTCPReceived struct { - Packets []rtcp.Packet - StreamID string - TrackID string + TrackID string + Packets []rtcp.Packet } diff --git a/pkg/peer/peer.go b/pkg/peer/peer.go index a226633..7961e1b 100644 --- a/pkg/peer/peer.go +++ b/pkg/peer/peer.go @@ -112,21 +112,21 @@ func (p *Peer[ID]) SubscribeTo(track *webrtc.TrackLocalStaticRTP) error { p.logger.WithError(err).Warn("failed to read RTCP on track") } - p.sink.Send(RTCPReceived{Packets: packets, TrackID: track.ID(), StreamID: track.StreamID()}) + p.sink.Send(RTCPReceived{Packets: packets, TrackID: track.ID()}) } }() return nil } -func (p *Peer[ID]) WriteRTCP(packets []rtcp.Packet, streamID string, trackID string, lastPLITimestamp time.Time) error { +func (p *Peer[ID]) WriteRTCP(packets []rtcp.Packet, trackID string, lastPLITimestamp time.Time) error { const minimalPLIInterval = time.Millisecond * 500 packetsToSend := []rtcp.Packet{} var mediaSSRC uint32 receivers := p.peerConnection.GetReceivers() receiverIndex := slices.IndexFunc(receivers, func(receiver *webrtc.RTPReceiver) bool { - return receiver.Track().ID() == trackID && receiver.Track().StreamID() == streamID + return receiver.Track().ID() == trackID }) if receiverIndex == -1 {