diff --git a/pkg/conference/participant/tracker.go b/pkg/conference/participant/tracker.go index efb550a..7b0a402 100644 --- a/pkg/conference/participant/tracker.go +++ b/pkg/conference/participant/tracker.go @@ -3,26 +3,34 @@ package participant import ( "fmt" - "github.com/matrix-org/waterfall/pkg/conference/subscription" + "github.com/matrix-org/waterfall/pkg/conference/track" "github.com/matrix-org/waterfall/pkg/webrtc_ext" - "github.com/pion/rtp" "github.com/pion/webrtc/v3" - "github.com/sirupsen/logrus" - "golang.org/x/exp/slices" ) +type TrackStoppedMessage struct { + TrackID track.TrackID + OwnerID ID +} + // Tracks participants and their corresponding tracks. // These are grouped together as the field in this structure must be kept synchronized. type Tracker struct { participants map[ID]*Participant - publishedTracks map[TrackID]*PublishedTrack + publishedTracks map[track.TrackID]*track.PublishedTrack[ID] + + publishedTrackStopped chan<- TrackStoppedMessage + conferenceEnded <-chan struct{} } -func NewParticipantTracker() *Tracker { +func NewParticipantTracker(conferenceEnded <-chan struct{}) (*Tracker, <-chan TrackStoppedMessage) { + publishedTrackStopped := make(chan TrackStoppedMessage) return &Tracker{ - participants: make(map[ID]*Participant), - publishedTracks: make(map[TrackID]*PublishedTrack), - } + participants: make(map[ID]*Participant), + publishedTracks: make(map[track.TrackID]*track.PublishedTrack[ID]), + publishedTrackStopped: publishedTrackStopped, + conferenceEnded: conferenceEnded, + }, publishedTrackStopped } // Adds a new participant in the list. @@ -62,9 +70,9 @@ func (t *Tracker) RemoveParticipant(participantID ID) map[string]bool { // Remove the participant's tracks from all participants who might have subscribed to them. streamIdentifiers := make(map[string]bool) for trackID, track := range t.publishedTracks { - if track.Owner == participantID { + if track.Owner() == participantID { // Odd way to add to a set in Go. - streamIdentifiers[track.Info.StreamID] = true + streamIdentifiers[track.Info().StreamID] = true t.RemovePublishedTrack(trackID) } } @@ -72,10 +80,7 @@ func (t *Tracker) RemoveParticipant(participantID ID) map[string]bool { // Go over all subscriptions and remove the participant from them. // TODO: Perhaps we could simply react to the subscrpitions dying and remove them from the list. for _, publishedTrack := range t.publishedTracks { - if subscription, found := publishedTrack.Subscriptions[participantID]; found { - subscription.Unsubscribe() - delete(publishedTrack.Subscriptions, participantID) - } + publishedTrack.Unsubscribe(participantID) } return streamIdentifiers @@ -85,151 +90,98 @@ func (t *Tracker) RemoveParticipant(participantID ID) map[string]bool { // that has been published and that we must take into account from now on. func (t *Tracker) AddPublishedTrack( participantID ID, - info webrtc_ext.TrackInfo, - simulcast webrtc_ext.SimulcastLayer, - metadata TrackMetadata, - outputTrack *webrtc.TrackLocalStaticRTP, -) { - // If this is a new track, let's add it to the list of published and inform participants. - track, found := t.publishedTracks[info.TrackID] - if !found { - layers := []webrtc_ext.SimulcastLayer{} - if simulcast != webrtc_ext.SimulcastLayerNone { - layers = append(layers, simulcast) - } + remoteTrack *webrtc.TrackRemote, + metadata track.TrackMetadata, +) error { + participant := t.participants[participantID] + if participant == nil { + return fmt.Errorf("participant %s does not exist", participantID) + } - t.publishedTracks[info.TrackID] = &PublishedTrack{ - Owner: participantID, - Info: info, - Layers: layers, - Metadata: metadata, - OutputTrack: outputTrack, - Subscriptions: make(map[ID]subscription.Subscription), + // If this is a new track, let's add it to the list of published and inform participants. + if published, found := t.publishedTracks[remoteTrack.ID()]; found { + if err := published.AddPublisher(remoteTrack); err != nil { + return err } - return + return nil } - // If it's just a new layer, let's add it to the list of layers of the existing published track. - fn := func(layer webrtc_ext.SimulcastLayer) bool { return layer == simulcast } - if simulcast != webrtc_ext.SimulcastLayerNone && slices.IndexFunc(track.Layers, fn) == -1 { - track.Layers = append(track.Layers, simulcast) - t.publishedTracks[info.TrackID] = track + published, err := track.NewPublishedTrack( + participantID, + participant.Peer.RequestKeyFrame, + remoteTrack, + metadata, + participant.Logger, + ) + if err != nil { + return err } + + // Wait for the track to complete and inform the conference about it. + go func() { + // Wait for the track to complete. + <-published.Done() + + // Inform the conference that the track is gone. Or stop the go-routine if the conference stopped. + select { + case t.publishedTrackStopped <- TrackStoppedMessage{remoteTrack.ID(), participantID}: + case <-t.conferenceEnded: + } + }() + + t.publishedTracks[remoteTrack.ID()] = published + return nil } // Iterates over published tracks and calls a closure upon each track info. func (t *Tracker) ForEachPublishedTrackInfo(fn func(ID, webrtc_ext.TrackInfo)) { for _, track := range t.publishedTracks { - fn(track.Owner, track.Info) + fn(track.Owner(), track.Info()) } } // Updates metadata associated with a given track. -func (t *Tracker) UpdatePublishedTrackMetadata(id TrackID, metadata TrackMetadata) { +func (t *Tracker) UpdatePublishedTrackMetadata(id track.TrackID, metadata track.TrackMetadata) { if track, found := t.publishedTracks[id]; found { - track.Metadata = metadata + track.SetMetadata(metadata) t.publishedTracks[id] = track } } // Informs the tracker that one of the previously published tracks is gone. -func (t *Tracker) RemovePublishedTrack(id TrackID) { +func (t *Tracker) RemovePublishedTrack(id track.TrackID) { if publishedTrack, found := t.publishedTracks[id]; found { - // Iterate over all subscriptions and end them. - for subscriberID, subscription := range publishedTrack.Subscriptions { - subscription.Unsubscribe() - delete(publishedTrack.Subscriptions, subscriberID) - } - + publishedTrack.Stop() delete(t.publishedTracks, id) } } // Subscribes a given participant to the track. -func (t *Tracker) Subscribe(participantID ID, trackID TrackID, requirements TrackMetadata) error { +func (t *Tracker) Subscribe(participantID ID, trackID track.TrackID, requirements track.TrackMetadata) error { // Check if the participant exists that wants to subscribe exists. participant := t.participants[participantID] if participant == nil { return fmt.Errorf("participant %s does not exist", participantID) } - // Check if the track that we want to subscribe to exists. + // Check if the track that we want to subscribe exists. published := t.publishedTracks[trackID] if published == nil { return fmt.Errorf("track %s does not exist", trackID) } - // Calculate the desired simulcast layer. - desiredLayer := published.GetOptimalLayer(requirements.MaxWidth, requirements.MaxHeight) - - // If the subscription exists, let's see if we need to update it. - if sub := published.Subscriptions[participantID]; sub != nil { - if sub.Simulcast() != desiredLayer { - sub.SwitchLayer(desiredLayer) - return nil - } - - return fmt.Errorf("subscription already exists and up-to-date") - } - - // Find the owner of the track that we're trying to subscribe to. - owner := t.participants[published.Owner] - if owner == nil { - return fmt.Errorf("owner of the track %s does not exist", published.Info.TrackID) - } - - var ( - sub subscription.Subscription - err error - ) - - // Subscription does not exist, so let's create it. - switch published.Info.Kind { - case webrtc.RTPCodecTypeVideo: - sub, err = subscription.NewVideoSubscription( - published.Info, - desiredLayer, - participant.Peer, - func(track webrtc_ext.TrackInfo, simulcast webrtc_ext.SimulcastLayer) error { - return owner.Peer.RequestKeyFrame(track, simulcast) - }, - participant.Logger, - ) - case webrtc.RTPCodecTypeAudio: - sub, err = subscription.NewAudioSubscription(published.OutputTrack, participant.Peer) - } - - // If there was an error, let's return it. - if err != nil { + // Subscribe to the track. + if err := published.Subscribe(participantID, participant.Peer, requirements, participant.Logger); err != nil { return err } - // Add the subscription to the list of subscriptions. - published.Subscriptions[participantID] = sub - return nil } // Unsubscribes a given `participantID` from the track. -func (t *Tracker) Unsubscribe(participantID ID, trackID TrackID) { +func (t *Tracker) Unsubscribe(participantID ID, trackID track.TrackID) { if published := t.publishedTracks[trackID]; published != nil { - if sub := published.Subscriptions[participantID]; sub != nil { - sub.Unsubscribe() - delete(published.Subscriptions, participantID) - } - } -} - -// Processes an RTP packet received on a given track. -func (t *Tracker) ProcessRTP(info webrtc_ext.TrackInfo, simulcast webrtc_ext.SimulcastLayer, packet *rtp.Packet) { - if published := t.publishedTracks[info.TrackID]; published != nil { - for _, sub := range published.Subscriptions { - if sub.Simulcast() == simulcast { - if err := sub.WriteRTP(*packet); err != nil { - logrus.Errorf("Dropping an RTP packet on %s (%s): %s", info.TrackID, simulcast, err) - } - } - } + published.Unsubscribe(participantID) } } diff --git a/pkg/conference/peer_message_processing.go b/pkg/conference/peer_message_processing.go index 264ee43..c5e259c 100644 --- a/pkg/conference/peer_message_processing.go +++ b/pkg/conference/peer_message_processing.go @@ -2,6 +2,7 @@ package conference import ( "github.com/matrix-org/waterfall/pkg/conference/participant" + published "github.com/matrix-org/waterfall/pkg/conference/track" "github.com/matrix-org/waterfall/pkg/peer" "github.com/matrix-org/waterfall/pkg/signaling" "maunium.net/go/mautrix/event" @@ -23,23 +24,20 @@ func (c *Conference) processLeftTheCallMessage(sender participant.ID, msg peer.L } func (c *Conference) processNewTrackPublishedMessage(sender participant.ID, msg peer.NewTrackPublished) { - c.newLogger(sender).Infof("Published new track: %s (%v)", msg.TrackID, msg.SimulcastLayer) + id := msg.RemoteTrack.ID() + c.newLogger(sender).Infof("Published new track: %s (%v)", id, msg.RemoteTrack.RID()) // Find metadata for a given track. - trackMetadata := streamIntoTrackMetadata(c.streamsMetadata)[msg.TrackID] + trackMetadata := streamIntoTrackMetadata(c.streamsMetadata)[id] // If a new track has been published, we inform everyone about new track available. - c.tracker.AddPublishedTrack(sender, msg.TrackInfo, msg.SimulcastLayer, trackMetadata, msg.OutputTrack) + c.tracker.AddPublishedTrack(sender, msg.RemoteTrack, trackMetadata) c.resendMetadataToAllExcept(sender) } -func (c *Conference) processRTPPacketReceivedMessage(msg peer.RTPPacketReceived) { - c.tracker.ProcessRTP(msg.TrackInfo, msg.SimulcastLayer, msg.Packet) -} - -func (c *Conference) processPublishedTrackFailedMessage(sender participant.ID, msg peer.PublishedTrackFailed) { - c.newLogger(sender).Infof("Failed published track: %s", msg.TrackID) - c.tracker.RemovePublishedTrack(msg.TrackID) +func (c *Conference) processPublishedTrackFailedMessage(sender participant.ID, trackID published.TrackID) { + c.newLogger(sender).Infof("Failed published track: %s", trackID) + c.tracker.RemovePublishedTrack(trackID) c.resendMetadataToAllExcept(sender) } @@ -163,13 +161,13 @@ func (c *Conference) processTrackSubscriptionMessage( for _, track := range msg.Subscribe { p.Logger.Debugf("Subscribing to track %s", track.TrackID) - requirements := participant.TrackMetadata{track.Width, track.Height} + requirements := published.TrackMetadata{track.Width, track.Height} if err := c.tracker.Subscribe(p.ID, track.TrackID, requirements); err != nil { p.Logger.Errorf("Failed to subscribe to track %s: %v", track.TrackID, err) continue } - p.Logger.Debugf("Subscribed to track %s", track.TrackID) + p.Logger.Infof("Subscribed to track %s", track.TrackID) } } diff --git a/pkg/conference/processing.go b/pkg/conference/processing.go index eecfc26..06208b1 100644 --- a/pkg/conference/processing.go +++ b/pkg/conference/processing.go @@ -21,6 +21,8 @@ func (c *Conference) processMessages(signalDone chan struct{}) { c.processPeerMessage(msg) case msg := <-c.matrixEvents: c.processMatrixMessage(msg) + case msg := <-c.publishedTrackStopped: + c.processPublishedTrackFailedMessage(msg.OwnerID, msg.TrackID) } // If there are no more participants, stop the conference. @@ -42,10 +44,6 @@ func (c *Conference) processPeerMessage(message channel.Message[participant.ID, c.processLeftTheCallMessage(message.Sender, msg) case peer.NewTrackPublished: c.processNewTrackPublishedMessage(message.Sender, msg) - case peer.RTPPacketReceived: - c.processRTPPacketReceivedMessage(msg) - case peer.PublishedTrackFailed: - c.processPublishedTrackFailedMessage(message.Sender, msg) case peer.NewICECandidate: c.processNewICECandidateMessage(message.Sender, msg) case peer.ICEGatheringComplete: diff --git a/pkg/conference/publisher/publisher.go b/pkg/conference/publisher/publisher.go new file mode 100644 index 0000000..fdd102f --- /dev/null +++ b/pkg/conference/publisher/publisher.go @@ -0,0 +1,116 @@ +package publisher + +import ( + "errors" + "sync" + + "github.com/pion/rtp" + "github.com/sirupsen/logrus" +) + +var ErrSubscriptionExists = errors.New("subscription already exists") + +type Subscription interface { + // WriteRTP **must not** block (wait on I/O). + WriteRTP(packet rtp.Packet) error +} + +type Track interface { + // ReadPacket **may** block (wait on I/O). + ReadPacket() (*rtp.Packet, error) +} + +// An abstract publisher that reads the packets from the track and forwards them to all subscribers. +type Publisher struct { + logger *logrus.Entry + + mu sync.Mutex + track Track + subscriptions map[Subscription]struct{} +} + +func NewPublisher( + track Track, + stop <-chan struct{}, + log *logrus.Entry, +) (*Publisher, <-chan struct{}) { + // Create a done channel, so that we can signal the caller when we're done. + done := make(chan struct{}) + + publisher := &Publisher{ + logger: log, + track: track, + subscriptions: make(map[Subscription]struct{}), + } + + // Start a goroutine that will read RTP packets from the remote track. + // We run the publisher until we receive a stop signal or an error occurs. + go func() { + defer close(done) + for { + // Check if we were signaled to stop. + select { + case <-stop: + return + default: + if err := publisher.forwardPacket(); err != nil { + log.Errorf("failed to read the frame from the track %s", err) + return + } + } + } + }() + + return publisher, done +} + +func (p *Publisher) AddSubscription(subscription Subscription) { + p.mu.Lock() + defer p.mu.Unlock() + + if _, ok := p.subscriptions[subscription]; ok { + return + } + + p.subscriptions[subscription] = struct{}{} +} + +func (p *Publisher) RemoveSubscription(subscription Subscription) { + p.mu.Lock() + defer p.mu.Unlock() + delete(p.subscriptions, subscription) +} + +func (p *Publisher) GetTrack() Track { + p.mu.Lock() + defer p.mu.Unlock() + return p.track +} + +func (p *Publisher) ReplaceTrack(track Track) { + p.mu.Lock() + defer p.mu.Unlock() + p.track = track +} + +// Reads a single packet from the remote track and forwards it to all subscribers. +func (p *Publisher) forwardPacket() error { + track := p.GetTrack() + + packet, err := track.ReadPacket() + if err != nil { + return err + } + + p.mu.Lock() + defer p.mu.Unlock() + + // Write the packet to all subscribers. + for subscription := range p.subscriptions { + if err := subscription.WriteRTP(*packet); err != nil { + p.logger.Warnf("packet dropped on the subscription: %s", err) + } + } + + return nil +} diff --git a/pkg/conference/publisher/track.go b/pkg/conference/publisher/track.go new file mode 100644 index 0000000..658da43 --- /dev/null +++ b/pkg/conference/publisher/track.go @@ -0,0 +1,18 @@ +package publisher + +import ( + "github.com/pion/rtp" + "github.com/pion/webrtc/v3" +) + +// Wrapper for the `webrtc.TrackRemote`. +type RemoteTrack struct { + // The underlying `webrtc.TrackRemote`. + Track *webrtc.TrackRemote +} + +// Implement the `Track` interface for the `webrtc.TrackRemote`. +func (t *RemoteTrack) ReadPacket() (*rtp.Packet, error) { + packet, _, err := t.Track.ReadRTP() + return packet, err +} diff --git a/pkg/conference/start.go b/pkg/conference/start.go index cc749c3..521a97d 100644 --- a/pkg/conference/start.go +++ b/pkg/conference/start.go @@ -38,16 +38,20 @@ func StartConference( userID id.UserID, inviteEvent *event.CallInviteEventContent, ) (<-chan struct{}, error) { + signalDone := make(chan struct{}) + + tracker, publishedTrackStopped := participant.NewParticipantTracker(signalDone) conference := &Conference{ - id: confID, - config: config, - connectionFactory: peerConnectionFactory, - logger: logrus.WithFields(logrus.Fields{"conf_id": confID}), - matrixWorker: newMatrixWorker(signaling), - tracker: *participant.NewParticipantTracker(), - streamsMetadata: make(event.CallSDPStreamMetadata), - peerMessages: make(chan channel.Message[participant.ID, peer.MessageContent], 100), - matrixEvents: matrixEvents, + id: confID, + config: config, + connectionFactory: peerConnectionFactory, + logger: logrus.WithFields(logrus.Fields{"conf_id": confID}), + matrixWorker: newMatrixWorker(signaling), + tracker: tracker, + streamsMetadata: make(event.CallSDPStreamMetadata), + peerMessages: make(chan channel.Message[participant.ID, peer.MessageContent], 100), + matrixEvents: matrixEvents, + publishedTrackStopped: publishedTrackStopped, } participantID := participant.ID{UserID: userID, DeviceID: inviteEvent.DeviceID, CallID: inviteEvent.CallID} @@ -56,7 +60,6 @@ func StartConference( } // Start conference "main loop". - signalDone := make(chan struct{}) go conference.processMessages(signalDone) return signalDone, nil diff --git a/pkg/conference/state.go b/pkg/conference/state.go index bf62f18..8c58f05 100644 --- a/pkg/conference/state.go +++ b/pkg/conference/state.go @@ -3,6 +3,7 @@ package conference import ( "github.com/matrix-org/waterfall/pkg/channel" "github.com/matrix-org/waterfall/pkg/conference/participant" + published "github.com/matrix-org/waterfall/pkg/conference/track" "github.com/matrix-org/waterfall/pkg/peer" "github.com/matrix-org/waterfall/pkg/webrtc_ext" "github.com/sirupsen/logrus" @@ -11,19 +12,19 @@ import ( // A single conference. Call and conference mean the same in context of Matrix. type Conference struct { - id string - config Config - logger *logrus.Entry - conferenceDone chan<- struct{} + id string + config Config + logger *logrus.Entry connectionFactory *webrtc_ext.PeerConnectionFactory matrixWorker *matrixWorker - tracker participant.Tracker + tracker *participant.Tracker streamsMetadata event.CallSDPStreamMetadata - peerMessages chan channel.Message[participant.ID, peer.MessageContent] - matrixEvents <-chan MatrixMessage + peerMessages chan channel.Message[participant.ID, peer.MessageContent] + matrixEvents <-chan MatrixMessage + publishedTrackStopped <-chan participant.TrackStoppedMessage } func (c *Conference) getParticipant(id participant.ID) *participant.Participant { @@ -114,11 +115,11 @@ func (c *Conference) updateMetadata(metadata event.CallSDPStreamMetadata) { func streamIntoTrackMetadata( streamMetadata event.CallSDPStreamMetadata, -) map[participant.TrackID]participant.TrackMetadata { - tracksMetadata := make(map[participant.TrackID]participant.TrackMetadata) +) map[published.TrackID]published.TrackMetadata { + tracksMetadata := make(map[published.TrackID]published.TrackMetadata) for _, metadata := range streamMetadata { for id, track := range metadata.Tracks { - tracksMetadata[id] = participant.TrackMetadata{ + tracksMetadata[id] = published.TrackMetadata{ MaxWidth: track.Width, MaxHeight: track.Height, } diff --git a/pkg/conference/subscription/video.go b/pkg/conference/subscription/video.go index 9a21892..f8d58f1 100644 --- a/pkg/conference/subscription/video.go +++ b/pkg/conference/subscription/video.go @@ -16,7 +16,7 @@ import ( "github.com/sirupsen/logrus" ) -type RequestKeyFrameFn = func(track webrtc_ext.TrackInfo, simulcast webrtc_ext.SimulcastLayer) error +type RequestKeyFrameFn = func(simulcast webrtc_ext.SimulcastLayer) error type VideoSubscription struct { rtpSender *webrtc.RTPSender @@ -71,12 +71,14 @@ func NewVideoSubscription( // Configure the worker for the subscription. workerConfig := worker.Config[rtp.Packet]{ - ChannelSize: 32, - Timeout: 3 * time.Second, + ChannelSize: 16, // We really don't need a large buffer here, just to account for spikes. + Timeout: 3 * time.Second, // When do we assume the subscription is stalled. OnTimeout: func() { layer := webrtc_ext.SimulcastLayer(subscription.currentLayer.Load()) + // TODO: At this point we probably need to send some message back + // to the conference and switch the quality of remove the + // subscription. This must not happen under normal circumstances. logger.Warnf("No RTP on subscription %s (%s)", subscription.info.TrackID, layer) - subscription.requestKeyFrame() }, OnTask: workerState.handlePacket, } @@ -107,7 +109,7 @@ func (s *VideoSubscription) WriteRTP(packet rtp.Packet) error { func (s *VideoSubscription) SwitchLayer(simulcast webrtc_ext.SimulcastLayer) { s.logger.Infof("Switching layer on %s to %s", s.info.TrackID, simulcast) s.currentLayer.Store(int32(simulcast)) - s.requestKeyFrame() + s.requestKeyFrameFn(simulcast) } func (s *VideoSubscription) TrackInfo() webrtc_ext.TrackInfo { @@ -143,10 +145,7 @@ func (s *VideoSubscription) readRTCP() { } func (s *VideoSubscription) requestKeyFrame() { - layer := webrtc_ext.SimulcastLayer(s.currentLayer.Load()) - if err := s.requestKeyFrameFn(s.info, layer); err != nil { - s.logger.Errorf("Failed to request key frame: %s", err) - } + s.requestKeyFrameFn(webrtc_ext.SimulcastLayer(s.currentLayer.Load())) } // Internal state of a worker that runs in its own goroutine. diff --git a/pkg/conference/track/internal.go b/pkg/conference/track/internal.go new file mode 100644 index 0000000..4bf959b --- /dev/null +++ b/pkg/conference/track/internal.go @@ -0,0 +1,103 @@ +package track + +import ( + "github.com/matrix-org/waterfall/pkg/conference/publisher" + "github.com/matrix-org/waterfall/pkg/webrtc_ext" + "github.com/matrix-org/waterfall/pkg/worker" + "github.com/pion/webrtc/v3" +) + +type trackOwner[SubscriberID comparable] struct { + owner SubscriberID + requestKeyFrame func(track *webrtc.TrackRemote) error +} + +type audioTrack struct { + // The sink of this audio track packets. + outputTrack *webrtc.TrackLocalStaticRTP +} + +type videoTrack struct { + // Publishers of each video layer. + publishers map[webrtc_ext.SimulcastLayer]*publisher.Publisher + // Key frame request handler. + keyframeHandler *worker.Worker[webrtc_ext.SimulcastLayer] +} + +// Forward audio packets from the source track to the destination track. +func forward(sender *webrtc.TrackRemote, receiver *webrtc.TrackLocalStaticRTP, stop <-chan struct{}) error { + for { + // Read the data from the remote track. + packet, _, readErr := sender.ReadRTP() + if readErr != nil { + return readErr + } + + // Write the data to the local track. + if writeErr := receiver.WriteRTP(packet); writeErr != nil { + return writeErr + } + + // Check if we need to stop processing packets. + select { + case <-stop: + return nil + default: + } + } +} + +func (p *PublishedTrack[SubscriberID]) addVideoPublisher(track *webrtc.TrackRemote) { + pub, done := publisher.NewPublisher(&publisher.RemoteTrack{track}, p.stopPublishers, p.logger) + simulcast := webrtc_ext.RIDToSimulcastLayer(track.RID()) + p.video.publishers[simulcast] = pub + + // Listen on `done` and remove the track once it's done. + p.activePublishers.Add(1) + go func() { + defer p.activePublishers.Done() + <-done + + p.mutex.Lock() + defer p.mutex.Unlock() + + // Remove the publisher once it's gone. + delete(p.video.publishers, simulcast) + + // Find any other available layer, so that we can switch subscriptions that lost their publisher + // to a new publisher (at least they'll get some data). + var ( + availableLayer webrtc_ext.SimulcastLayer + availablePublisher *publisher.Publisher + ) + for layer, pub := range p.video.publishers { + availableLayer = layer + availablePublisher = pub + break + } + + // Now iterate over all subscriptions and find those that are now lost due to the publisher being away. + for subID, sub := range p.subscriptions { + if sub.Simulcast() == simulcast { + // If there is some other publisher on a different layer, let's switch to it + if availablePublisher != nil { + sub.SwitchLayer(availableLayer) + pub.AddSubscription(sub) + } else { + // Otherwise, let's just remove the subscription. + sub.Unsubscribe() + delete(p.subscriptions, subID) + } + } + } + }() +} + +func (p *PublishedTrack[SubscriberID]) isClosed() bool { + select { + case <-p.done: + return true + default: + return false + } +} diff --git a/pkg/conference/track/keyframe.go b/pkg/conference/track/keyframe.go new file mode 100644 index 0000000..37fdaca --- /dev/null +++ b/pkg/conference/track/keyframe.go @@ -0,0 +1,42 @@ +package track + +import ( + "fmt" + + "github.com/matrix-org/waterfall/pkg/conference/publisher" + "github.com/matrix-org/waterfall/pkg/webrtc_ext" + "github.com/pion/webrtc/v3" +) + +func (p *PublishedTrack[SubscriberID]) handleKeyFrameRequest(simulcast webrtc_ext.SimulcastLayer) error { + publisher := p.getPublisher(simulcast) + if publisher == nil { + return fmt.Errorf("publisher with simulcast %s not found", simulcast) + } + + track, err := extractRemoteTrack(publisher) + if err != nil { + return err + } + + return p.owner.requestKeyFrame(track) +} + +func (p *PublishedTrack[SubscriberID]) getPublisher(simulcast webrtc_ext.SimulcastLayer) *publisher.Publisher { + p.mutex.Lock() + defer p.mutex.Unlock() + + // Get the track that we need to request a key frame for. + return p.video.publishers[simulcast] +} + +func extractRemoteTrack(pub *publisher.Publisher) (*webrtc.TrackRemote, error) { + // Get the track that we need to request a key frame for. + track := pub.GetTrack() + remoteTrack, ok := track.(*publisher.RemoteTrack) + if !ok { + return nil, fmt.Errorf("not a remote track in publisher") + } + + return remoteTrack.Track, nil +} diff --git a/pkg/conference/participant/track.go b/pkg/conference/track/simulcast.go similarity index 60% rename from pkg/conference/participant/track.go rename to pkg/conference/track/simulcast.go index fd68cc6..b481f9e 100644 --- a/pkg/conference/participant/track.go +++ b/pkg/conference/track/simulcast.go @@ -1,41 +1,28 @@ -package participant +package track import ( - "github.com/matrix-org/waterfall/pkg/conference/subscription" "github.com/matrix-org/waterfall/pkg/webrtc_ext" - "github.com/pion/webrtc/v3" - "golang.org/x/exp/slices" ) -type TrackID = string - -// Represents a track that a peer has published (has already started sending to the SFU). -type PublishedTrack struct { - // Owner of a published track. - Owner ID - // Info about the track. - Info webrtc_ext.TrackInfo - // Available simulcast Layers. - Layers []webrtc_ext.SimulcastLayer - // Track metadata. - Metadata TrackMetadata - // Output track (if any). I.e. a track that would contain all RTP packets - // of the given published track. Currently only audio tracks will have it. - OutputTrack *webrtc.TrackLocalStaticRTP - // All available subscriptions for this particular track. - Subscriptions map[ID]subscription.Subscription +// Metadata that we have received about this track from a user. +// This metadata is only set for video tracks at the moment. +type TrackMetadata struct { + MaxWidth, MaxHeight int } // Calculate the layer that we can use based on the requirements passed as parameters and available layers. -func (p *PublishedTrack) GetOptimalLayer(requestedWidth, requestedHeight int) webrtc_ext.SimulcastLayer { - // Audio track. For them we don't have any simulcast. We also don't have any simulcast for video - // if there was no simulcast enabled at all. - if p.Info.Kind == webrtc.RTPCodecTypeAudio || len(p.Layers) == 0 { +func getOptimalLayer( + layers map[webrtc_ext.SimulcastLayer]struct{}, + metadata TrackMetadata, + requestedWidth, requestedHeight int, +) webrtc_ext.SimulcastLayer { + // If we don't have any layers available, then there is no simulcast. + if _, found := layers[webrtc_ext.SimulcastLayerNone]; found || len(layers) == 0 { return webrtc_ext.SimulcastLayerNone } // Video track. Calculate the optimal layer closest to the requested resolution. - desiredLayer := calculateDesiredLayer(p.Metadata.MaxWidth, p.Metadata.MaxHeight, requestedWidth, requestedHeight) + desiredLayer := calculateDesiredLayer(metadata.MaxWidth, metadata.MaxHeight, requestedWidth, requestedHeight) // Ideally, here we would need to send an error if the desired layer is not available, but we don't // have a way to do it. So we just return the closest available layer. @@ -48,12 +35,8 @@ func (p *PublishedTrack) GetOptimalLayer(requestedWidth, requestedHeight int) we // More Go boilerplate. for _, desiredLayer := range priority { - layerIndex := slices.IndexFunc(p.Layers, func(simulcast webrtc_ext.SimulcastLayer) bool { - return simulcast == desiredLayer - }) - - if layerIndex != -1 { - return p.Layers[layerIndex] + if _, found := layers[desiredLayer]; found { + return desiredLayer } } @@ -62,12 +45,6 @@ func (p *PublishedTrack) GetOptimalLayer(requestedWidth, requestedHeight int) we return webrtc_ext.SimulcastLayerLow } -// Metadata that we have received about this track from a user. -// This metadata is only set for video tracks at the moment. -type TrackMetadata struct { - MaxWidth, MaxHeight int -} - // Calculates the optimal layer closest to the requested resolution. We assume that the full resolution is the // maximum resolution that we can get from the user. We assume that a medium quality layer is half the size of // the video (**but not half of the resolution**). I.e. medium quality is high quality divided by 4. And low diff --git a/pkg/conference/track/track.go b/pkg/conference/track/track.go new file mode 100644 index 0000000..163b21e --- /dev/null +++ b/pkg/conference/track/track.go @@ -0,0 +1,267 @@ +package track + +import ( + "fmt" + "sync" + "time" + + "github.com/matrix-org/waterfall/pkg/conference/publisher" + "github.com/matrix-org/waterfall/pkg/conference/subscription" + "github.com/matrix-org/waterfall/pkg/webrtc_ext" + "github.com/matrix-org/waterfall/pkg/worker" + "github.com/pion/webrtc/v3" + "github.com/sirupsen/logrus" +) + +type TrackID = string + +// Represents a track that a peer has published (has already started sending to the SFU). +type PublishedTrack[SubscriberID comparable] struct { + // Logger. + logger *logrus.Entry + // Info about the track. + info webrtc_ext.TrackInfo + // Owner of a published track. + owner trackOwner[SubscriberID] + + // We must protect the data with a mutex since we want the `PublishedTrack` to remain thread-safe. + mutex sync.Mutex + // Currently active subscriptions for this track. + subscriptions map[SubscriberID]subscription.Subscription + // Audio track data. The content will be `nil` if it's not an audio track. + audio *audioTrack + // Video track. The content will be `nil` if it's not a video track. + video *videoTrack + // Track metadata. + metadata TrackMetadata + + // Wait group for all active publishers. + activePublishers *sync.WaitGroup + // A signal to publishers **to stop** them all. + stopPublishers chan struct{} + // A aignal to inform the caller that all publishers of this track **have been stopped**. + done chan struct{} +} + +func NewPublishedTrack[SubscriberID comparable]( + ownerID SubscriberID, + requestKeyFrame func(track *webrtc.TrackRemote) error, + track *webrtc.TrackRemote, + metadata TrackMetadata, + logger *logrus.Entry, +) (*PublishedTrack[SubscriberID], error) { + published := &PublishedTrack[SubscriberID]{ + logger: logger, + info: webrtc_ext.TrackInfoFromTrack(track), + owner: trackOwner[SubscriberID]{ownerID, requestKeyFrame}, + subscriptions: make(map[SubscriberID]subscription.Subscription), + audio: &audioTrack{outputTrack: nil}, + video: &videoTrack{publishers: make(map[webrtc_ext.SimulcastLayer]*publisher.Publisher)}, + metadata: metadata, + activePublishers: &sync.WaitGroup{}, + stopPublishers: make(chan struct{}), + done: make(chan struct{}), + } + + switch published.info.Kind { + case webrtc.RTPCodecTypeAudio: + // Create a local track, all our SFU clients that are subscribed to this + // peer (publisher) wil be fed via this track. + localTrack, err := webrtc.NewTrackLocalStaticRTP( + track.Codec().RTPCodecCapability, + track.ID(), + track.StreamID(), + ) + if err != nil { + return nil, err + } + + published.audio.outputTrack = localTrack + + // Start audio publisher in a separate goroutine. + published.activePublishers.Add(1) + go func() { + defer published.activePublishers.Done() + if err := forward(track, localTrack, published.stopPublishers); err != nil { + logger.Errorf("audio publisher stopped: %s", err) + } + }() + + case webrtc.RTPCodecTypeVideo: + // Configure and start a worker to process incoming key frame requests. + workerConfig := worker.Config[webrtc_ext.SimulcastLayer]{ + ChannelSize: 16, + Timeout: 1 * time.Hour, + OnTimeout: func() {}, + OnTask: func(simulcast webrtc_ext.SimulcastLayer) { + published.handleKeyFrameRequest(simulcast) + }, + } + + worker := worker.StartWorker[webrtc_ext.SimulcastLayer](workerConfig) + published.video.keyframeHandler = worker + + // Start video publisher. + published.addVideoPublisher(track) + } + + // Wait for all publishers to stop. + go func() { + defer close(published.done) + published.activePublishers.Wait() + }() + + return published, nil +} + +// Adds a new publisher to the existing `PublishedTrack`, this happens if we +// have multiple qualities (layers) on a single track. +func (p *PublishedTrack[SubscriberID]) AddPublisher(track *webrtc.TrackRemote) error { + if p.isClosed() { + return fmt.Errorf("track is already closed") + } + + info := webrtc_ext.TrackInfoFromTrack(track) + if info.TrackID != p.info.TrackID || p.info.Kind.String() != info.Kind.String() { + return fmt.Errorf("track mismatch") + } + + // Such publisher already exists. Let's replace the track that provides frames with a new one. + simulcast := webrtc_ext.RIDToSimulcastLayer(track.RID()) + + // Lock the mutex since we access the publishers from multiple threads. + p.mutex.Lock() + defer p.mutex.Unlock() + + // If the publisher for this track already exists, let's replace the track. This may happen during + // the negotiation when the SSRC changes and Pion fires a new track for the track that has already + // been published. + if pub := p.video.publishers[simulcast]; pub != nil { + pub.ReplaceTrack(&publisher.RemoteTrack{track}) + return nil + } + + // Add a publisher and start polling it. + p.addVideoPublisher(track) + return nil +} + +// Stops the published track and all related publishers. You should not use the +// `PublishedTrack` after calling this method. +func (p *PublishedTrack[SubscriberID]) Stop() { + // Command all publishers to stop, unless already stopped. + if !p.isClosed() { + close(p.stopPublishers) + } +} + +// Create a new subscription for a given subscriber or update the existing one if necessary. +func (p *PublishedTrack[SubscriberID]) Subscribe( + subscriberID SubscriberID, + controller subscription.SubscriptionController, + requirements TrackMetadata, + logger *logrus.Entry, +) error { + if p.isClosed() { + return fmt.Errorf("track is already closed") + } + + // Lock the mutex as we access subscriptions and publishers from multiple threads. + p.mutex.Lock() + defer p.mutex.Unlock() + + // Let's calculate the desired simulcast layer (if any). + var layer webrtc_ext.SimulcastLayer + if p.info.Kind == webrtc.RTPCodecTypeVideo { + layers := make(map[webrtc_ext.SimulcastLayer]struct{}, len(p.video.publishers)) + for key := range p.video.publishers { + layers[key] = struct{}{} + } + layer = getOptimalLayer(layers, p.metadata, requirements.MaxWidth, requirements.MaxHeight) + } + + // If the subscription exists, let's see if we need to update it. + if sub := p.subscriptions[subscriberID]; sub != nil { + currentLayer := sub.Simulcast() + + // If we do, let's switch the layer. + if currentLayer != layer { + p.video.publishers[currentLayer].RemoveSubscription(sub) + sub.SwitchLayer(layer) + p.video.publishers[layer].AddSubscription(sub) + } + + // Subsription is up-to-date, nothing to change. + return nil + } + + var ( + sub subscription.Subscription + err error + ) + + // Subscription does not exist, so let's create it. + switch p.info.Kind { + case webrtc.RTPCodecTypeVideo: + handler := func(simulcast webrtc_ext.SimulcastLayer) error { + return p.video.keyframeHandler.Send(simulcast) + } + sub, err = subscription.NewVideoSubscription(p.info, layer, controller, handler, logger) + case webrtc.RTPCodecTypeAudio: + sub, err = subscription.NewAudioSubscription(p.audio.outputTrack, controller) + } + + // If there was an error, let's return it. + if err != nil { + return err + } + + // Add the subscription to the list of subscriptions. + p.subscriptions[subscriberID] = sub + + // And if it's a video subscription, add it to the list of subscriptions that get the feed from the publisher. + if p.info.Kind == webrtc.RTPCodecTypeVideo { + p.video.publishers[layer].AddSubscription(sub) + } + + return nil +} + +// Remove subscriptions with a given subscriber id. +func (p *PublishedTrack[SubscriberID]) Unsubscribe(subscriberID SubscriberID) { + p.mutex.Lock() + defer p.mutex.Unlock() + + if sub := p.subscriptions[subscriberID]; sub != nil { + sub.Unsubscribe() + delete(p.subscriptions, subscriberID) + + if p.info.Kind == webrtc.RTPCodecTypeVideo { + p.video.publishers[sub.Simulcast()].RemoveSubscription(sub) + } + } +} + +func (p *PublishedTrack[SubscriberID]) Owner() SubscriberID { + return p.owner.owner +} + +func (p *PublishedTrack[SubscriberID]) Info() webrtc_ext.TrackInfo { + return p.info +} + +func (p *PublishedTrack[SubscriberID]) Done() <-chan struct{} { + return p.done +} + +func (p *PublishedTrack[SubscriberID]) Metadata() TrackMetadata { + p.mutex.Lock() + defer p.mutex.Unlock() + return p.metadata +} + +func (p *PublishedTrack[SubscriberID]) SetMetadata(metadata TrackMetadata) { + p.mutex.Lock() + defer p.mutex.Unlock() + p.metadata = metadata +} diff --git a/pkg/conference/participant/track_test.go b/pkg/conference/track/track_test.go similarity index 73% rename from pkg/conference/participant/track_test.go rename to pkg/conference/track/track_test.go index 3d8f9a1..9825660 100644 --- a/pkg/conference/participant/track_test.go +++ b/pkg/conference/track/track_test.go @@ -1,11 +1,9 @@ -package participant_test +package track //nolint:testpackage import ( "testing" - "github.com/matrix-org/waterfall/pkg/conference/participant" "github.com/matrix-org/waterfall/pkg/webrtc_ext" - "github.com/pion/webrtc/v3" ) func TestGetOptimalLayer(t *testing.T) { @@ -43,33 +41,29 @@ func TestGetOptimalLayer(t *testing.T) { {layers(high), 1280, 720, 200, 200, high}, } - mock := participant.PublishedTrack{ - Info: webrtc_ext.TrackInfo{ - Kind: webrtc.RTPCodecTypeVideo, - }, - } - for _, c := range cases { - mock.Layers = c.availableLayers - mock.Metadata.MaxWidth = c.fullWidth - mock.Metadata.MaxHeight = c.fullHeight + metadata := TrackMetadata{ + MaxWidth: c.fullWidth, + MaxHeight: c.fullHeight, + } - optimalLayer := mock.GetOptimalLayer(c.desiredWidth, c.desiredHeight) + layers := make(map[webrtc_ext.SimulcastLayer]struct{}, len(c.availableLayers)) + for _, layer := range c.availableLayers { + layers[layer] = struct{}{} + } + + optimalLayer := getOptimalLayer(layers, metadata, c.desiredWidth, c.desiredHeight) if optimalLayer != c.expectedOptimalLayer { t.Errorf("Expected optimal layer %s, got %s", c.expectedOptimalLayer, optimalLayer) } } } -func TestGetOptimalLayerAudio(t *testing.T) { - mock := participant.PublishedTrack{ - Info: webrtc_ext.TrackInfo{ - Kind: webrtc.RTPCodecTypeAudio, - }, - } +func TestGetOptimalLayerNone(t *testing.T) { + layers := make(map[webrtc_ext.SimulcastLayer]struct{}) + metadata := TrackMetadata{} - mock.Layers = []webrtc_ext.SimulcastLayer{webrtc_ext.SimulcastLayerLow} - if mock.GetOptimalLayer(100, 100) != webrtc_ext.SimulcastLayerNone { + if getOptimalLayer(layers, metadata, 100, 100) != webrtc_ext.SimulcastLayerNone { t.Fatal("Expected no simulcast layer for audio") } } diff --git a/pkg/peer/messages.go b/pkg/peer/messages.go index 4f67a6d..896c01b 100644 --- a/pkg/peer/messages.go +++ b/pkg/peer/messages.go @@ -1,8 +1,6 @@ package peer import ( - "github.com/matrix-org/waterfall/pkg/webrtc_ext" - "github.com/pion/rtp" "github.com/pion/webrtc/v3" "maunium.net/go/mautrix/event" ) @@ -18,25 +16,8 @@ type LeftTheCall struct { } type NewTrackPublished struct { - // Information about the track (ID etc). - webrtc_ext.TrackInfo - // SimulcastLayer configuration (can be `None` for non-simulcast tracks and for audio tracks). - SimulcastLayer webrtc_ext.SimulcastLayer - // Output track (if any) that could be used to send data to the peer. Will be `nil` if such - // track does not exist, in which case the caller is expected to listen to `RtpPacketReceived` - // messages. - OutputTrack *webrtc.TrackLocalStaticRTP -} - -type PublishedTrackFailed struct { - webrtc_ext.TrackInfo - SimulcastLayer webrtc_ext.SimulcastLayer -} - -type RTPPacketReceived struct { - webrtc_ext.TrackInfo - SimulcastLayer webrtc_ext.SimulcastLayer - Packet *rtp.Packet + // Remote track that has been published. + RemoteTrack *webrtc.TrackRemote } type NewICECandidate struct { diff --git a/pkg/peer/peer.go b/pkg/peer/peer.go index 0b20650..7211acf 100644 --- a/pkg/peer/peer.go +++ b/pkg/peer/peer.go @@ -21,7 +21,6 @@ var ( ErrDataChannelNotAvailable = errors.New("data channel is not available") ErrDataChannelNotReady = errors.New("data channel is not ready") ErrCantSubscribeToTrack = errors.New("can't subscribe to track") - ErrTrackNotFound = errors.New("track not found") ) // A wrapped representation of the peer connection (single peer in the call). @@ -84,14 +83,7 @@ func (p *Peer[ID]) Terminate() { } // Request a key frame from the peer connection. -func (p *Peer[ID]) RequestKeyFrame(info webrtc_ext.TrackInfo, simulcast webrtc_ext.SimulcastLayer) error { - // Find the right track. - track := p.state.GetRemoteTrack(info.TrackID, simulcast) - if track == nil { - return ErrTrackNotFound - } - - p.logger.Debugf("Keyframe request: %s (%s)", info.TrackID, simulcast) +func (p *Peer[ID]) RequestKeyFrame(track *webrtc.TrackRemote) error { rtcps := []rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: uint32(track.SSRC())}} return p.peerConnection.WriteRTCP(rtcps) } diff --git a/pkg/peer/remote_track.go b/pkg/peer/remote_track.go deleted file mode 100644 index 97a82ae..0000000 --- a/pkg/peer/remote_track.go +++ /dev/null @@ -1,88 +0,0 @@ -package peer - -import ( - "errors" - "io" - - "github.com/matrix-org/waterfall/pkg/webrtc_ext" - "github.com/pion/rtp" - "github.com/pion/webrtc/v3" -) - -func (p *Peer[ID]) handleNewVideoTrack( - trackInfo webrtc_ext.TrackInfo, - remoteTrack *webrtc.TrackRemote, - receiver *webrtc.RTPReceiver, -) { - simulcast := webrtc_ext.RIDToSimulcastLayer(remoteTrack.RID()) - - p.handleRemoteTrack(remoteTrack, trackInfo, simulcast, nil, func(packet *rtp.Packet) error { - p.sink.Send(RTPPacketReceived{trackInfo, simulcast, packet}) - return nil - }) -} - -func (p *Peer[ID]) handleNewAudioTrack( - trackInfo webrtc_ext.TrackInfo, - remoteTrack *webrtc.TrackRemote, - receiver *webrtc.RTPReceiver, -) { - // Create a local track, all our SFU clients that are subscribed to this - // peer (publisher) wil be fed via this track. - localTrack, err := webrtc.NewTrackLocalStaticRTP( - remoteTrack.Codec().RTPCodecCapability, - remoteTrack.ID(), - remoteTrack.StreamID(), - ) - if err != nil { - p.logger.WithError(err).Error("failed to create local track") - return - } - - p.handleRemoteTrack(remoteTrack, trackInfo, webrtc_ext.SimulcastLayerNone, localTrack, func(packet *rtp.Packet) error { - if err = localTrack.WriteRTP(packet); err != nil && !errors.Is(err, io.ErrClosedPipe) { - return err - } - return nil - }) -} - -func (p *Peer[ID]) handleRemoteTrack( - remoteTrack *webrtc.TrackRemote, - trackInfo webrtc_ext.TrackInfo, - simulcast webrtc_ext.SimulcastLayer, - outputTrack *webrtc.TrackLocalStaticRTP, - handleRtpFn func(*rtp.Packet) error, -) { - // Notify others that our track has just been published. - p.state.AddRemoteTrack(remoteTrack) - p.sink.Send(NewTrackPublished{trackInfo, simulcast, outputTrack}) - - // Start a go-routine that reads the data from the remote track. - go func() { - // Call this when this goroutine ends. - defer func() { - p.state.RemoveRemoteTrack(remoteTrack) - p.sink.Send(PublishedTrackFailed{trackInfo, simulcast}) - }() - - for { - // Read the data from the remote track. - packet, _, readErr := remoteTrack.ReadRTP() - if readErr != nil { - if readErr == io.EOF { // finished, no more data, no error, inform others - p.logger.Info("remote track closed") - } else { // finished, no more data, but with error, inform others - p.logger.WithError(readErr).Error("failed to read from remote track") - } - return - } - - // Handle the RTP packet. - if err := handleRtpFn(packet); err != nil { - p.logger.WithError(err).Error("failed to handle RTP packet") - return - } - } - }() -} diff --git a/pkg/peer/state/peer_state.go b/pkg/peer/state/peer_state.go index b277139..819dd8d 100644 --- a/pkg/peer/state/peer_state.go +++ b/pkg/peer/state/peer_state.go @@ -3,46 +3,16 @@ package state import ( "sync" - "github.com/matrix-org/waterfall/pkg/webrtc_ext" "github.com/pion/webrtc/v3" ) -type RemoteTrackId struct { - id string - simulcast webrtc_ext.SimulcastLayer -} - type PeerState struct { - mutex sync.Mutex - dataChannel *webrtc.DataChannel - remoteTracks map[RemoteTrackId]*webrtc.TrackRemote + mutex sync.Mutex + dataChannel *webrtc.DataChannel } func NewPeerState() *PeerState { - return &PeerState{ - remoteTracks: make(map[RemoteTrackId]*webrtc.TrackRemote), - } -} - -func (p *PeerState) AddRemoteTrack(track *webrtc.TrackRemote) { - p.mutex.Lock() - defer p.mutex.Unlock() - - p.remoteTracks[RemoteTrackId{track.ID(), webrtc_ext.RIDToSimulcastLayer(track.RID())}] = track -} - -func (p *PeerState) RemoveRemoteTrack(track *webrtc.TrackRemote) { - p.mutex.Lock() - defer p.mutex.Unlock() - - delete(p.remoteTracks, RemoteTrackId{track.ID(), webrtc_ext.RIDToSimulcastLayer(track.RID())}) -} - -func (p *PeerState) GetRemoteTrack(id string, simulcast webrtc_ext.SimulcastLayer) *webrtc.TrackRemote { - p.mutex.Lock() - defer p.mutex.Unlock() - - return p.remoteTracks[RemoteTrackId{id, simulcast}] + return &PeerState{} } func (p *PeerState) SetDataChannel(dc *webrtc.DataChannel) { diff --git a/pkg/peer/webrtc_callbacks.go b/pkg/peer/webrtc_callbacks.go index 31288b8..7bac093 100644 --- a/pkg/peer/webrtc_callbacks.go +++ b/pkg/peer/webrtc_callbacks.go @@ -1,7 +1,6 @@ package peer import ( - "github.com/matrix-org/waterfall/pkg/webrtc_ext" "github.com/pion/webrtc/v3" "maunium.net/go/mautrix/event" ) @@ -9,15 +8,8 @@ import ( // A callback that is called once we receive first RTP packets from a track, i.e. // we call this function each time a new track is received. func (p *Peer[ID]) onRtpTrackReceived(remoteTrack *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) { - // Construct a new track info assuming that there is no simulcast. - trackInfo := webrtc_ext.TrackInfoFromTrack(remoteTrack) - - switch trackInfo.Kind { - case webrtc.RTPCodecTypeVideo: - p.handleNewVideoTrack(trackInfo, remoteTrack, receiver) - case webrtc.RTPCodecTypeAudio: - p.handleNewAudioTrack(trackInfo, remoteTrack, receiver) - } + p.logger.WithField("track", remoteTrack).Debug("RTP track received") + p.sink.Send(NewTrackPublished{remoteTrack}) } // A callback that is called once we receive an ICE candidate for this peer connection.