From bf6b9877de23f95a61d4291297d0a4872d514836 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Thu, 17 Nov 2022 22:11:56 +0100 Subject: [PATCH 01/62] refactor: clean up main.go a bit --- .golangci.yaml | 1 + src/logger.go | 6 ++++ src/main.go | 91 ++++++++++---------------------------------------- src/matrix.go | 2 +- 4 files changed, 26 insertions(+), 74 deletions(-) diff --git a/.golangci.yaml b/.golangci.yaml index a27073d..71884b8 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -29,4 +29,5 @@ linters: - gomnd # we use status code numbers and for our use case it's not practical - godox # we have TODOs at this stage of the project, enable in future - forbidigo # we use things like fmt.Printf for debugging, enable in future + - wsl # somehow this plugin causes more harm than use as it enables lots of things to be configured without causing spaghetti-code (grouping similar things together) fast: true diff --git a/src/logger.go b/src/logger.go index 2d849fc..e3a99b0 100644 --- a/src/logger.go +++ b/src/logger.go @@ -22,6 +22,12 @@ import ( "github.com/sirupsen/logrus" ) +func initLogging(logTime bool) { + formatter := new(CustomTextFormatter) + formatter.logTime = logTime + logrus.SetFormatter(formatter) +} + type CustomTextFormatter struct { logrus.TextFormatter logTime bool diff --git a/src/main.go b/src/main.go index 80dbaaf..7694f01 100644 --- a/src/main.go +++ b/src/main.go @@ -20,105 +20,50 @@ import ( "flag" "os" "os/signal" - "runtime" - "runtime/pprof" "syscall" "github.com/sirupsen/logrus" ) -func initCPUProfiling(cpuProfile *string) func() { - logrus.Info("initializing CPU profiling") - - file, err := os.Create(*cpuProfile) - if err != nil { - logrus.WithError(err).Fatal("could not create CPU profile") - } - - if err := pprof.StartCPUProfile(file); err != nil { - logrus.WithError(err).Fatal("could not start CPU profile") - } - - return func() { - pprof.StopCPUProfile() - - if err := file.Close(); err != nil { - logrus.WithError(err).Fatal("could not close CPU profile") - } - } -} - -func initMemoryProfiling(memProfile *string) func() { - logrus.Info("initializing memory profiling") - - return func() { - file, err := os.Create(*memProfile) - if err != nil { - logrus.WithError(err).Fatal("could not create memory profile") - } - - runtime.GC() - - if err := pprof.WriteHeapProfile(file); err != nil { - logrus.WithError(err).Fatal("could not write memory profile") - } - - if err = file.Close(); err != nil { - logrus.WithError(err).Fatal("could not close memory profile") - } - } -} - -func initLogging(logTime *bool) { - formatter := new(CustomTextFormatter) - - formatter.logTime = *logTime - - logrus.SetFormatter(formatter) -} - -func killListener(c chan os.Signal, beforeExit []func()) { - <-c - - for _, function := range beforeExit { - function() - } - - defer os.Exit(0) -} - func main() { + // Parse command line flags. var ( logTime = flag.Bool("logTime", false, "whether or not to print time and date in logs") configFilePath = flag.String("config", "config.yaml", "configuration file path") cpuProfile = flag.String("cpuProfile", "", "write CPU profile to `file`") memProfile = flag.String("memProfile", "", "write memory profile to `file`") ) - flag.Parse() - initLogging(logTime) + // Initialize logging subsystem (formatting, global logging framework etc). + initLogging(*logTime) - beforeExit := []func(){} + // Define functions that are called before exiting. + // This is useful to stop the profiler if it's enabled. + deferred_functions := []func(){} if *cpuProfile != "" { - beforeExit = append(beforeExit, initCPUProfiling(cpuProfile)) + deferred_functions = append(deferred_functions, initCPUProfiling(cpuProfile)) } - if *memProfile != "" { - beforeExit = append(beforeExit, initMemoryProfiling(memProfile)) + deferred_functions = append(deferred_functions, initMemoryProfiling(memProfile)) } - // try to handle os interrupt(signal terminated) - //nolint:gomnd + // Handle signal interruptions. c := make(chan os.Signal, 2) signal.Notify(c, os.Interrupt, syscall.SIGTERM) + go func() { + <-c + for _, function := range deferred_functions { + function() + } + os.Exit(0) + }() - go killListener(c, beforeExit) - + // Load the config file from the environment variable or path. config, err := loadConfig(*configFilePath) if err != nil { logrus.WithError(err).Fatal("could not load config") } - InitMatrix(config) + RunSFU(config) } diff --git a/src/matrix.go b/src/matrix.go index db493ba..d604fa6 100644 --- a/src/matrix.go +++ b/src/matrix.go @@ -25,7 +25,7 @@ import ( const localSessionID = "sfu" -func InitMatrix(config *Config) { +func RunSFU(config *Config) { client, err := mautrix.NewClient(config.HomeserverURL, config.UserID, config.AccessToken) if err != nil { logrus.WithError(err).Fatal("Failed to create client") From c5c6206cd3304a6c2c1ca7efd684813279fc03f0 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Fri, 18 Nov 2022 00:11:19 +0100 Subject: [PATCH 02/62] refactor: renaming modules and functions names Pretty generic work-in-progress state. --- .golangci.yaml | 1 + src/call.go | 8 +-- src/conference.go | 11 ++- src/focus.go | 176 ---------------------------------------------- src/logger.go | 2 +- src/main.go | 8 +-- src/matrix.go | 20 +++--- src/profile.go | 53 ++++++++++++++ src/sfu.go | 173 +++++++++++++++++++++++++++++++++++++++++++++ 9 files changed, 251 insertions(+), 201 deletions(-) delete mode 100644 src/focus.go create mode 100644 src/profile.go create mode 100644 src/sfu.go diff --git a/.golangci.yaml b/.golangci.yaml index 71884b8..1ae85b5 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -30,4 +30,5 @@ linters: - godox # we have TODOs at this stage of the project, enable in future - forbidigo # we use things like fmt.Printf for debugging, enable in future - wsl # somehow this plugin causes more harm than use as it enables lots of things to be configured without causing spaghetti-code (grouping similar things together) + - nlreturn # not always practical, it was disabled before strict lints were introduced, then added, now it's clear why it was disabled at the first place :) fast: true diff --git a/src/call.go b/src/call.go index 490c021..e92c853 100644 --- a/src/call.go +++ b/src/call.go @@ -65,7 +65,7 @@ func (c *Call) InitWithInvite(evt *event.Event, client *mautrix.Client) { c.UserID = evt.Sender c.DeviceID = invite.DeviceID // XXX: What if an SFU gets restarted? - c.LocalSessionID = localSessionID + c.LocalSessionID = LocalSessionID c.RemoteSessionID = invite.SenderSessionID c.client = client c.logger = logrus.WithFields(logrus.Fields{ @@ -181,7 +181,7 @@ func (c *Call) onDCAlive() { func (c *Call) onDCMetadata() { c.logger.Info("received DC metadata") - c.conf.SendUpdatedMetadataFromCall(c.CallID) + c.conf.SendUpdatedMetadataFromPeer(c.CallID) } func (c *Call) dataChannelHandler(channel *webrtc.DataChannel) { @@ -287,7 +287,7 @@ func (c *Call) iceCandidateHandler(candidate *webrtc.ICECandidate) { func (c *Call) trackHandler(trackRemote *webrtc.TrackRemote) { NewPublisher(trackRemote, c) - go c.conf.SendUpdatedMetadataFromCall(c.CallID) + go c.conf.SendUpdatedMetadataFromPeer(c.CallID) } func (c *Call) iceConnectionStateHandler(state webrtc.ICEConnectionState) { @@ -436,7 +436,7 @@ func (c *Call) Terminate() { subscriber.Unsubscribe() } - c.conf.SendUpdatedMetadataFromCall(c.CallID) + c.conf.SendUpdatedMetadataFromPeer(c.CallID) } func (c *Call) Hangup(reason event.CallHangupReason) { diff --git a/src/conference.go b/src/conference.go index 1d08064..110fd07 100644 --- a/src/conference.go +++ b/src/conference.go @@ -31,7 +31,7 @@ var ( ) // Configuration for the group conferences (calls). -type ConferenceConfig struct { +type CallConfig struct { // Keep-alive timeout for WebRTC connections. If no keep-alive has been received // from the client for this duration, the connection is considered dead. KeepAliveTimeout int @@ -39,17 +39,16 @@ type ConferenceConfig struct { type Conference struct { ConfID string - Calls map[string]*Call // By callID - Config *ConferenceConfig // TODO: this must be protected by a mutex actually + Calls map[string]*Call // By callID + Config *CallConfig // TODO: this must be protected by a mutex actually mutex sync.RWMutex logger *logrus.Entry Metadata *Metadata } -func NewConference(confID string, config *ConferenceConfig) *Conference { +func NewConference(confID string, config *CallConfig) *Conference { conference := new(Conference) - conference.Config = config conference.ConfID = confID conference.Calls = make(map[string]*Call) @@ -94,7 +93,7 @@ func (c *Conference) RemoveOldCallsByDeviceAndSessionIDs(deviceID id.DeviceID, s return err } -func (c *Conference) SendUpdatedMetadataFromCall(callID string) { +func (c *Conference) SendUpdatedMetadataFromPeer(callID string) { for _, call := range c.Calls { if call.CallID != callID { call.SendDataChannelMessage(event.SFUMessage{Op: event.SFUOperationMetadata}) diff --git a/src/focus.go b/src/focus.go deleted file mode 100644 index c2851c6..0000000 --- a/src/focus.go +++ /dev/null @@ -1,176 +0,0 @@ -/* -Copyright 2022 The Matrix.org Foundation C.I.C. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package main - -import ( - "errors" - "strings" - "sync" - - "github.com/sirupsen/logrus" - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/event" -) - -var ErrNoSuchConference = errors.New("no such conf") - -type Confs struct { - confsMu sync.RWMutex - confs map[string]*Conference -} - -type Focus struct { - name string - client *mautrix.Client - confs Confs - logger *logrus.Entry - config *ConferenceConfig -} - -func NewFocus(name string, client *mautrix.Client, config *ConferenceConfig) *Focus { - focus := new(Focus) - - focus.name = name - focus.client = client - focus.confs.confs = make(map[string]*Conference) - focus.logger = logrus.WithFields(logrus.Fields{}) - focus.config = config - - return focus -} - -func (f *Focus) GetConf(confID string, create bool) (*Conference, error) { - f.confs.confsMu.Lock() - defer f.confs.confsMu.Unlock() - conference := f.confs.confs[confID] - - if conference == nil { - if create { - conference = NewConference(confID, f.config) - f.confs.confs[confID] = conference - } else { - return nil, ErrNoSuchConference - } - } - - return conference, nil -} - -func (f *Focus) getExistingCall(confID string, callID string) (*Call, error) { - var ( - conf *Conference - call *Call - err error - ) - - if conf, err = f.GetConf(confID, false); err != nil || conf == nil { - f.logger.Printf("failed to get conf %s: %s", confID, err) - return nil, err - } - - if call, err = conf.GetCall(callID, false); err != nil || call == nil { - f.logger.Printf("failed to get call %s: %s", callID, err) - return nil, err - } - - return call, nil -} - -func (f *Focus) onEvent(_ mautrix.EventSource, evt *event.Event) { - // We only care about to-device events - if evt.Type.Class != event.ToDeviceEventType { - return - } - - evtLogger := f.logger.WithFields(logrus.Fields{ - "type": evt.Type.Type, - "user_id": evt.Sender.String(), - "conf_id": evt.Content.Raw["conf_id"], - }) - - if !strings.HasPrefix(evt.Type.Type, "m.call.") && !strings.HasPrefix(evt.Type.Type, "org.matrix.call.") { - evtLogger.Warn("received non-call to-device event") - return - } else if evt.Type.Type != event.ToDeviceCallCandidates.Type && evt.Type.Type != event.ToDeviceCallSelectAnswer.Type { - evtLogger.Info("received to-device event") - } - - if evt.Content.Raw["dest_session_id"] != localSessionID { - evtLogger.WithField("dest_session_id", localSessionID).Warn("SessionID does not match our SessionID - ignoring") - return - } - - var ( - conf *Conference - call *Call - err error - ) - - switch evt.Type.Type { - case event.ToDeviceCallInvite.Type: - invite := evt.Content.AsCallInvite() - if conf, err = f.GetConf(invite.ConfID, true); err != nil { - evtLogger.WithError(err).WithFields(logrus.Fields{ - "conf_id": invite.ConfID, - }).Error("failed to create conf") - - return - } - - if err := conf.RemoveOldCallsByDeviceAndSessionIDs(invite.DeviceID, invite.SenderSessionID); err != nil { - evtLogger.WithError(err).Error("error removing old calls - ignoring call") - return - } - - if call, err = conf.GetCall(invite.CallID, true); err != nil || call == nil { - evtLogger.WithError(err).Error("failed to create call") - return - } - - call.InitWithInvite(evt, f.client) - call.OnInvite(invite) - case event.ToDeviceCallCandidates.Type: - candidates := evt.Content.AsCallCandidates() - if call, err = f.getExistingCall(candidates.ConfID, candidates.CallID); err != nil { - return - } - - call.OnCandidates(candidates) - case event.ToDeviceCallSelectAnswer.Type: - selectAnswer := evt.Content.AsCallSelectAnswer() - if call, err = f.getExistingCall(selectAnswer.ConfID, selectAnswer.CallID); err != nil { - return - } - - call.OnSelectAnswer(selectAnswer) - case event.ToDeviceCallHangup.Type: - hangup := evt.Content.AsCallHangup() - if call, err = f.getExistingCall(hangup.ConfID, hangup.CallID); err != nil { - return - } - - call.OnHangup() - // Events we don't care about - case event.ToDeviceCallNegotiate.Type: - evtLogger.Warn("ignoring event as it should be handled over DC") - case event.ToDeviceCallReject.Type: - case event.ToDeviceCallAnswer.Type: - evtLogger.Warn("ignoring event as we are always the ones answering") - default: - evtLogger.Warn("ignoring unrecognised event") - } -} diff --git a/src/logger.go b/src/logger.go index e3a99b0..9d725ae 100644 --- a/src/logger.go +++ b/src/logger.go @@ -22,7 +22,7 @@ import ( "github.com/sirupsen/logrus" ) -func initLogging(logTime bool) { +func InitLogging(logTime bool) { formatter := new(CustomTextFormatter) formatter.logTime = logTime logrus.SetFormatter(formatter) diff --git a/src/main.go b/src/main.go index 7694f01..ec4bfea 100644 --- a/src/main.go +++ b/src/main.go @@ -36,16 +36,16 @@ func main() { flag.Parse() // Initialize logging subsystem (formatting, global logging framework etc). - initLogging(*logTime) + InitLogging(*logTime) // Define functions that are called before exiting. // This is useful to stop the profiler if it's enabled. deferred_functions := []func(){} if *cpuProfile != "" { - deferred_functions = append(deferred_functions, initCPUProfiling(cpuProfile)) + deferred_functions = append(deferred_functions, InitCPUProfiling(cpuProfile)) } if *memProfile != "" { - deferred_functions = append(deferred_functions, initMemoryProfiling(memProfile)) + deferred_functions = append(deferred_functions, InitMemoryProfiling(memProfile)) } // Handle signal interruptions. @@ -65,5 +65,5 @@ func main() { logrus.WithError(err).Fatal("could not load config") } - RunSFU(config) + RunServer(config) } diff --git a/src/matrix.go b/src/matrix.go index d604fa6..0e67a65 100644 --- a/src/matrix.go +++ b/src/matrix.go @@ -17,15 +17,15 @@ limitations under the License. package main import ( - "fmt" - "github.com/sirupsen/logrus" "maunium.net/go/mautrix" ) -const localSessionID = "sfu" +const LocalSessionID = "sfu" -func RunSFU(config *Config) { +// Starts the Matrix client and connects to the homeserver, +// runs the SFU. Returns only when the sync with Matrix fails. +func RunServer(config *Config) { client, err := mautrix.NewClient(config.HomeserverURL, config.UserID, config.AccessToken) if err != nil { logrus.WithError(err).Fatal("Failed to create client") @@ -43,10 +43,9 @@ func RunSFU(config *Config) { logrus.WithField("device_id", whoami.DeviceID).Info("Identified SFU as DeviceID") client.DeviceID = whoami.DeviceID - focus := NewFocus( - fmt.Sprintf("%s (%s)", config.UserID, client.DeviceID), + focus := NewSFU( client, - &ConferenceConfig{KeepAliveTimeout: config.KeepAliveTimeout}, + &CallConfig{KeepAliveTimeout: config.KeepAliveTimeout}, ) syncer, ok := client.Syncer.(*mautrix.DefaultSyncer) @@ -55,10 +54,11 @@ func RunSFU(config *Config) { } syncer.ParseEventContent = true + syncer.OnEvent(focus.onMatrixEvent) - // TODO: E2EE - syncer.OnEvent(focus.onEvent) - + // TODO: We may want to reconnect if `Sync()` fails instead of ending the SFU + // as ending here will essentially drop all conferences which may not necessarily + // be what we want for the existing running conferences. if err = client.Sync(); err != nil { logrus.WithError(err).Panic("Sync failed") } diff --git a/src/profile.go b/src/profile.go new file mode 100644 index 0000000..6ca8b45 --- /dev/null +++ b/src/profile.go @@ -0,0 +1,53 @@ +package main + +import ( + "os" + "runtime" + "runtime/pprof" + + "github.com/sirupsen/logrus" +) + +// Initializes CPU profiling and returns a function to stop profiling. +func InitCPUProfiling(cpuProfile *string) func() { + logrus.Info("initializing CPU profiling") + + file, err := os.Create(*cpuProfile) + if err != nil { + logrus.WithError(err).Fatal("could not create CPU profile") + } + + if err := pprof.StartCPUProfile(file); err != nil { + logrus.WithError(err).Fatal("could not start CPU profile") + } + + return func() { + pprof.StopCPUProfile() + + if err := file.Close(); err != nil { + logrus.WithError(err).Fatal("could not close CPU profile") + } + } +} + +// Initializes memory profiling and returns a function to stop profiling. +func InitMemoryProfiling(memProfile *string) func() { + logrus.Info("initializing memory profiling") + + return func() { + file, err := os.Create(*memProfile) + if err != nil { + logrus.WithError(err).Fatal("could not create memory profile") + } + + runtime.GC() + + if err := pprof.WriteHeapProfile(file); err != nil { + logrus.WithError(err).Fatal("could not write memory profile") + } + + if err = file.Close(); err != nil { + logrus.WithError(err).Fatal("could not close memory profile") + } + } +} diff --git a/src/sfu.go b/src/sfu.go new file mode 100644 index 0000000..c183b23 --- /dev/null +++ b/src/sfu.go @@ -0,0 +1,173 @@ +/* +Copyright 2022 The Matrix.org Foundation C.I.C. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package main + +import ( + "errors" + + "github.com/sirupsen/logrus" + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/event" +) + +var ErrNoSuchConference = errors.New("no such conference") + +// The top-level state of the SFU. +// Note that in Matrix MSCs, the term "focus" is used to refer to the SFU. But since "focus" is a very +// generic name and only makes sense in a certain context, we use the term "SFU" instead to avoid confusion +// given that this particular part is just the SFU logic (and not the "focus" selection algorithm etc). +type SFU struct { + // Matrix client. + client *mautrix.Client + // All calls currently forwarded by this SFU. + conferences map[string]*Conference + // Structured logging for the SFU. + logger *logrus.Entry + // Configuration for the calls. + config *CallConfig +} + +// Creates a new instance of the SFU with the given configuration. +func NewSFU(client *mautrix.Client, config *CallConfig) *SFU { + return &SFU{ + client: client, + conferences: make(map[string]*Conference), + logger: logrus.WithField("module", "sfu"), + config: config, + } +} + +// Returns a conference by its `id`, or creates a new one if it doesn't exist yet. +func (f *SFU) GetOrCreateConference(confID string, create bool) (*Conference, error) { + if conference := f.conferences[confID]; conference != nil { + return conference, nil + } + + if create { + f.logger.Printf("creating new conference %s", confID) + conference := NewConference(confID, f.config) + f.conferences[confID] = conference + return conference, nil + } + + return nil, ErrNoSuchConference +} + +func (f *SFU) GetCall(confID string, callID string) (*Call, error) { + var ( + conf *Conference + call *Call + err error + ) + + if conf, err = f.GetOrCreateConference(confID, false); err != nil || conf == nil { + f.logger.Printf("failed to get conf %s: %s", confID, err) + return nil, err + } + + if call, err = conf.GetCall(callID, false); err != nil || call == nil { + f.logger.Printf("failed to get call %s: %s", callID, err) + return nil, err + } + + return call, nil +} + +// Handles To-Device events that the SFU receives from clients. +func (f *SFU) onMatrixEvent(_ mautrix.EventSource, evt *event.Event) { + // We only care about to-device events. + if evt.Type.Class != event.ToDeviceEventType { + f.logger.Warn("ignoring a not to-device event") + return + } + + evtLogger := f.logger.WithFields(logrus.Fields{ + "type": evt.Type.Type, + "user_id": evt.Sender.String(), + "conf_id": evt.Content.Raw["conf_id"], + }) + + if evt.Content.Raw["dest_session_id"] != LocalSessionID { + evtLogger.WithField("dest_session_id", LocalSessionID).Warn("SessionID does not match our SessionID - ignoring") + return + } + + var ( + conference *Conference + call *Call + err error + ) + + switch evt.Type.Type { + case event.ToDeviceCallInvite.Type: + invite := evt.Content.AsCallInvite() + if invite == nil { + evtLogger.Error("failed to parse invite") + return + } + + if conference, err = f.GetOrCreateConference(invite.ConfID, true); err != nil { + evtLogger.WithError(err).WithFields(logrus.Fields{ + "conf_id": invite.ConfID, + }).Error("failed to create conf") + + return + } + + if err := conference.RemoveOldCallsByDeviceAndSessionIDs(invite.DeviceID, invite.SenderSessionID); err != nil { + evtLogger.WithError(err).Error("error removing old calls - ignoring call") + return + } + + if call, err = conference.GetCall(invite.CallID, true); err != nil || call == nil { + evtLogger.WithError(err).Error("failed to create call") + return + } + + call.InitWithInvite(evt, f.client) + call.OnInvite(invite) + case event.ToDeviceCallCandidates.Type: + candidates := evt.Content.AsCallCandidates() + if call, err = f.GetCall(candidates.ConfID, candidates.CallID); err != nil { + return + } + + call.OnCandidates(candidates) + case event.ToDeviceCallSelectAnswer.Type: + selectAnswer := evt.Content.AsCallSelectAnswer() + if call, err = f.GetCall(selectAnswer.ConfID, selectAnswer.CallID); err != nil { + return + } + + call.OnSelectAnswer(selectAnswer) + case event.ToDeviceCallHangup.Type: + hangup := evt.Content.AsCallHangup() + if call, err = f.GetCall(hangup.ConfID, hangup.CallID); err != nil { + return + } + + call.OnHangup() + // Events that we **should not** receive! + case event.ToDeviceCallNegotiate.Type: + evtLogger.Warn("ignoring event as it should be handled over DC") + case event.ToDeviceCallReject.Type: + case event.ToDeviceCallAnswer.Type: + evtLogger.Warn("ignoring event as we are always the ones answering") + default: + evtLogger.Warnf("ignoring unexpected event: %s", evt.Type.Type) + } +} From ba2fa4bd7e2a49f3b859403be3254bf1a9dbac4d Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Fri, 18 Nov 2022 22:56:14 +0100 Subject: [PATCH 03/62] refactor: define a new skeleton for the project This is a WIP state that does not work as expected! --- src/call.go | 565 ----------------------------------- src/conference.go | 114 ------- src/conference/conference.go | 147 +++++++++ src/matrix.go | 3 +- src/metadata.go | 104 ------- src/peer/channel.go | 57 ++++ src/peer/id.go | 8 + src/peer/peer.go | 141 +++++++++ src/peer/webrtc.go | 178 +++++++++++ src/sfu.go | 208 ++++++++----- 10 files changed, 665 insertions(+), 860 deletions(-) delete mode 100644 src/call.go delete mode 100644 src/conference.go create mode 100644 src/conference/conference.go delete mode 100644 src/metadata.go create mode 100644 src/peer/channel.go create mode 100644 src/peer/id.go create mode 100644 src/peer/peer.go create mode 100644 src/peer/webrtc.go diff --git a/src/call.go b/src/call.go deleted file mode 100644 index e92c853..0000000 --- a/src/call.go +++ /dev/null @@ -1,565 +0,0 @@ -/* -Copyright 2022 The Matrix.org Foundation C.I.C. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package main - -import ( - "encoding/json" - "sync" - "time" - - "github.com/pion/webrtc/v3" - "github.com/sirupsen/logrus" - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" -) - -type Call struct { - PeerConnection *webrtc.PeerConnection - - CallID string - UserID id.UserID - DeviceID id.DeviceID - LocalSessionID id.SessionID - RemoteSessionID id.SessionID - - Publishers []*Publisher - Subscribers []*Subscriber - - mutex sync.RWMutex - logger *logrus.Entry - client *mautrix.Client - conf *Conference - - dataChannel *webrtc.DataChannel - lastKeepAliveTimestamp time.Time - sentEndOfCandidates bool -} - -func NewCall(callID string, conf *Conference) *Call { - call := new(Call) - - call.CallID = callID - call.conf = conf - - return call -} - -func (c *Call) InitWithInvite(evt *event.Event, client *mautrix.Client) { - invite := evt.Content.AsCallInvite() - - c.UserID = evt.Sender - c.DeviceID = invite.DeviceID - // XXX: What if an SFU gets restarted? - c.LocalSessionID = LocalSessionID - c.RemoteSessionID = invite.SenderSessionID - c.client = client - c.logger = logrus.WithFields(logrus.Fields{ - "user_id": evt.Sender, - "conf_id": invite.ConfID, - }) -} - -func (c *Call) onDCSelect(start []event.SFUTrackDescription) { - if len(start) == 0 { - return - } - - for _, trackDesc := range start { - trackLogger := c.logger.WithFields(logrus.Fields{ - "track_id": trackDesc.TrackID, - "stream_id": trackDesc.StreamID, - }) - - trackLogger.Info("selecting track") - - for _, publisher := range c.conf.GetPublishers() { - if publisher.Matches(trackDesc) { - publisher.Subscribe(c) - } - } - } -} - -func (c *Call) onDCPublish(sdp string) { - c.logger.Info("received DC publish") - - err := c.PeerConnection.SetRemoteDescription(webrtc.SessionDescription{ - Type: webrtc.SDPTypeOffer, - SDP: sdp, - }) - if err != nil { - c.logger.WithField("sdp", sdp).WithError(err).Error("failed to set remote description - ignoring") - return - } - - offer, err := c.PeerConnection.CreateAnswer(nil) - if err != nil { - c.logger.WithError(err).Error("failed to create answer - ignoring") - return - } - - err = c.PeerConnection.SetLocalDescription(offer) - if err != nil { - c.logger.WithField("sdp", offer.SDP).WithError(err).Error("failed to set local description - ignoring") - return - } - - c.SendDataChannelMessage(event.SFUMessage{ - Op: event.SFUOperationAnswer, - SDP: offer.SDP, - }) -} - -func (c *Call) onDCUnpublish(stop []event.SFUTrackDescription, sdp string) { - for _, trackDesc := range stop { - for _, publisher := range c.Publishers { - if publisher.Matches(trackDesc) { - publisher.Stop() - } - } - } - - err := c.PeerConnection.SetRemoteDescription(webrtc.SessionDescription{ - Type: webrtc.SDPTypeOffer, - SDP: sdp, - }) - if err != nil { - c.logger.WithField("sdp", sdp).WithError(err).Error("failed to set remote description - ignoring") - return - } - - offer, err := c.PeerConnection.CreateAnswer(nil) - if err != nil { - c.logger.WithError(err).Error("failed to create answer - ignoring") - return - } - - err = c.PeerConnection.SetLocalDescription(offer) - if err != nil { - c.logger.WithField("sdp", offer.SDP).WithError(err).Error("failed to set local description - ignoring") - return - } - - c.SendDataChannelMessage(event.SFUMessage{ - Op: event.SFUOperationAnswer, - SDP: offer.SDP, - }) -} - -func (c *Call) onDCAnswer(sdp string) { - c.logger.Info("received DC answer") - - err := c.PeerConnection.SetRemoteDescription(webrtc.SessionDescription{ - Type: webrtc.SDPTypeAnswer, - SDP: sdp, - }) - if err != nil { - c.logger.WithField("sdp", sdp).WithError(err).Error("failed to set remote description - ignoring") - return - } -} - -func (c *Call) onDCAlive() { - c.lastKeepAliveTimestamp = time.Now() -} - -func (c *Call) onDCMetadata() { - c.logger.Info("received DC metadata") - - c.conf.SendUpdatedMetadataFromPeer(c.CallID) -} - -func (c *Call) dataChannelHandler(channel *webrtc.DataChannel) { - c.dataChannel = channel - - channel.OnOpen(func() { - c.SendDataChannelMessage(event.SFUMessage{Op: event.SFUOperationMetadata}) - }) - - channel.OnError(func(err error) { - logrus.Fatalf("%s | DC error: %s", c.CallID, err) - }) - - channel.OnMessage(func(marshaledMsg webrtc.DataChannelMessage) { - if !marshaledMsg.IsString { - c.logger.WithField("msg", marshaledMsg).Warn("inbound message is not string - ignoring") - return - } - - msg := &event.SFUMessage{} - if err := json.Unmarshal(marshaledMsg.Data, msg); err != nil { - c.logger.WithField("msg", marshaledMsg).WithError(err).Error("failed to unmarshal - ignoring") - return - } - - if msg.Metadata != nil { - c.conf.Metadata.Update(c.DeviceID, msg.Metadata) - } - - switch msg.Op { - case event.SFUOperationSelect: - c.onDCSelect(msg.Start) - case event.SFUOperationPublish: - c.onDCPublish(msg.SDP) - case event.SFUOperationUnpublish: - c.onDCUnpublish(msg.Stop, msg.SDP) - case event.SFUOperationAnswer: - c.onDCAnswer(msg.SDP) - case event.SFUOperationAlive: - c.onDCAlive() - case event.SFUOperationMetadata: - c.onDCMetadata() - - default: - c.logger.WithField("op", msg.Op).Warn("Unknown operation - ignoring") - // TODO: hook up msg.Stop to unsubscribe from tracks - // TODO: hook cascade back up. - // As we're not an AS, we'd rely on the client - // to send us a "connect" op to tell us how to - // connect to another focus in order to select - // its streams. - } - }) -} - -func (c *Call) negotiationNeededHandler() { - offer, err := c.PeerConnection.CreateOffer(nil) - if err != nil { - c.logger.WithError(err).Error("failed to create offer - ignoring") - return - } - - err = c.PeerConnection.SetLocalDescription(offer) - if err != nil { - c.logger.WithField("sdp", offer.SDP).WithError(err).Error("failed to set local description - ignoring") - return - } - - c.SendDataChannelMessage(event.SFUMessage{ - Op: event.SFUOperationOffer, - SDP: offer.SDP, - }) -} - -func (c *Call) iceCandidateHandler(candidate *webrtc.ICECandidate) { - if candidate == nil { - return - } - - jsonCandidate := candidate.ToJSON() - - candidateEvtContent := &event.Content{ - Parsed: event.CallCandidatesEventContent{ - BaseCallEventContent: event.BaseCallEventContent{ - CallID: c.CallID, - ConfID: c.conf.ConfID, - DeviceID: c.client.DeviceID, - SenderSessionID: c.LocalSessionID, - DestSessionID: c.RemoteSessionID, - PartyID: string(c.client.DeviceID), - Version: event.CallVersion("1"), - }, - Candidates: []event.CallCandidate{{ - Candidate: jsonCandidate.Candidate, - SDPMLineIndex: int(*jsonCandidate.SDPMLineIndex), - SDPMID: *jsonCandidate.SDPMid, - }}, - }, - } - c.sendToDevice(event.CallCandidates, candidateEvtContent) -} - -func (c *Call) trackHandler(trackRemote *webrtc.TrackRemote) { - NewPublisher(trackRemote, c) - - go c.conf.SendUpdatedMetadataFromPeer(c.CallID) -} - -func (c *Call) iceConnectionStateHandler(state webrtc.ICEConnectionState) { - if state == webrtc.ICEConnectionStateCompleted || state == webrtc.ICEConnectionStateConnected { - c.lastKeepAliveTimestamp = time.Now() - go c.CheckKeepAliveTimestamp() - - if !c.sentEndOfCandidates { - candidateEvtContent := &event.Content{ - Parsed: event.CallCandidatesEventContent{ - BaseCallEventContent: event.BaseCallEventContent{ - CallID: c.CallID, - ConfID: c.conf.ConfID, - DeviceID: c.client.DeviceID, - SenderSessionID: c.LocalSessionID, - DestSessionID: c.RemoteSessionID, - PartyID: string(c.client.DeviceID), - Version: event.CallVersion("1"), - }, - Candidates: []event.CallCandidate{{Candidate: ""}}, - }, - } - c.sendToDevice(event.CallCandidates, candidateEvtContent) - c.sentEndOfCandidates = true - } - } -} - -func (c *Call) OnInvite(content *event.CallInviteEventContent) { - c.conf.Metadata.Update(c.DeviceID, content.SDPStreamMetadata) - offer := content.Offer - - peerConnection, err := webrtc.NewPeerConnection(webrtc.Configuration{}) - if err != nil { - logrus.WithError(err).Error("failed to create new peer connection") - } - - c.PeerConnection = peerConnection - - peerConnection.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) { - c.trackHandler(track) - }) - peerConnection.OnDataChannel(func(d *webrtc.DataChannel) { - c.dataChannelHandler(d) - }) - peerConnection.OnICECandidate(func(candidate *webrtc.ICECandidate) { - c.iceCandidateHandler(candidate) - }) - peerConnection.OnNegotiationNeeded(func() { - c.negotiationNeededHandler() - }) - peerConnection.OnICEConnectionStateChange(func(state webrtc.ICEConnectionState) { - c.iceConnectionStateHandler(state) - }) - - err = peerConnection.SetRemoteDescription(webrtc.SessionDescription{ - Type: webrtc.SDPTypeOffer, - SDP: offer.SDP, - }) - if err != nil { - c.logger.WithField("sdp", offer.SDP).WithError(err).Error("failed to set remote description - ignoring") - return - } - - answer, err := peerConnection.CreateAnswer(nil) - if err != nil { - c.logger.WithError(err).Error("failed to create answer - ignoring") - return - } - - // TODO: trickle ICE for fast conn setup, rather than block here - gatherComplete := webrtc.GatheringCompletePromise(peerConnection) - - if err = peerConnection.SetLocalDescription(answer); err != nil { - c.logger.WithField("sdp", offer.SDP).WithError(err).Error("failed to set local description - ignoring") - return - } - - <-gatherComplete - - answerEvtContent := &event.Content{ - Parsed: event.CallAnswerEventContent{ - BaseCallEventContent: event.BaseCallEventContent{ - CallID: c.CallID, - ConfID: c.conf.ConfID, - DeviceID: c.client.DeviceID, - SenderSessionID: c.LocalSessionID, - DestSessionID: c.RemoteSessionID, - PartyID: string(c.client.DeviceID), - Version: event.CallVersion("1"), - }, - Answer: event.CallData{ - Type: "answer", - SDP: peerConnection.LocalDescription().SDP, - }, - SDPStreamMetadata: c.conf.Metadata.GetForDevice(c.DeviceID), - }, - } - c.sendToDevice(event.CallAnswer, answerEvtContent) -} - -func (c *Call) OnSelectAnswer(content *event.CallSelectAnswerEventContent) { - selectedPartyID := content.SelectedPartyID - if selectedPartyID != string(c.client.DeviceID) { - c.logger.WithField("selected_party_id", selectedPartyID).Warn("call was answered on a different device") - c.Terminate() - } -} - -func (c *Call) OnHangup() { - c.Terminate() -} - -func (c *Call) OnCandidates(content *event.CallCandidatesEventContent) { - for _, candidate := range content.Candidates { - sdpMLineIndex := uint16(candidate.SDPMLineIndex) - ice := webrtc.ICECandidateInit{ - Candidate: candidate.Candidate, - SDPMid: &candidate.SDPMID, - SDPMLineIndex: &sdpMLineIndex, - UsernameFragment: new(string), - } - - if err := c.PeerConnection.AddICECandidate(ice); err != nil { - c.logger.WithField("content", content).WithError(err).Error("failed to add ICE candidate") - } - } -} - -func (c *Call) Terminate() { - c.logger.Info("terminating call") - - if err := c.PeerConnection.Close(); err != nil { - c.logger.WithError(err).Error("error closing peer connection") - } - - c.conf.mutex.Lock() - delete(c.conf.Calls, c.CallID) - c.conf.mutex.Unlock() - - for _, publisher := range c.Publishers { - publisher.Stop() - } - - for _, subscriber := range c.Subscribers { - subscriber.Unsubscribe() - } - - c.conf.SendUpdatedMetadataFromPeer(c.CallID) -} - -func (c *Call) Hangup(reason event.CallHangupReason) { - hangupEvtContent := &event.Content{ - Parsed: event.CallHangupEventContent{ - BaseCallEventContent: event.BaseCallEventContent{ - CallID: c.CallID, - ConfID: c.conf.ConfID, - DeviceID: c.client.DeviceID, - SenderSessionID: c.LocalSessionID, - DestSessionID: c.RemoteSessionID, - PartyID: string(c.client.DeviceID), - Version: event.CallVersion("1"), - }, - Reason: reason, - }, - } - c.sendToDevice(event.CallHangup, hangupEvtContent) - c.Terminate() -} - -func (c *Call) sendToDevice(callType event.Type, content *event.Content) { - evtLogger := c.logger.WithFields(logrus.Fields{ - "type": callType.Type, - }) - - if callType.Type != event.CallCandidates.Type { - evtLogger.Info("sending to device") - } - - toDevice := &mautrix.ReqSendToDevice{ - Messages: map[id.UserID]map[id.DeviceID]*event.Content{ - c.UserID: { - c.DeviceID: content, - }, - }, - } - - // TODO: E2EE - // TODO: to-device reliability - if _, err := c.client.SendToDevice(callType, toDevice); err != nil { - evtLogger.WithField("content", content).WithError(err).Error("error sending to-device") - } -} - -func (c *Call) SendDataChannelMessage(msg event.SFUMessage) { - if c.dataChannel == nil { - return - } - - evtLogger := c.logger.WithFields(logrus.Fields{ - "op": msg.Op, - }) - - if msg.Metadata == nil { - msg.Metadata = c.conf.Metadata.GetForDevice(c.DeviceID) - if msg.Op == event.SFUOperationMetadata && len(msg.Metadata) == 0 { - return - } - } - - marshaled, err := json.Marshal(msg) - if err != nil { - evtLogger.WithField("msg", msg).WithError(err).Error("failed to marshal - ignoring") - return - } - - err = c.dataChannel.SendText(string(marshaled)) - if err != nil { - evtLogger.WithField("msg", msg).WithError(err).Error("failed to send message over DC") - } - - evtLogger.Info("sent message over DC") -} - -func (c *Call) CheckKeepAliveTimestamp() { - timeout := time.Second * time.Duration(c.conf.Config.KeepAliveTimeout) - for range time.Tick(timeout) { - if c.lastKeepAliveTimestamp.Add(timeout).Before(time.Now()) { - if c.PeerConnection.ConnectionState() != webrtc.PeerConnectionStateClosed { - c.logger.WithField("timeout", timeout).Warn("did not get keep-alive message") - c.Hangup(event.CallHangupKeepAliveTimeout) - } - - break - } - } -} - -func (c *Call) RemoveSubscriber(toDelete *Subscriber) bool { - removed := false - newSubscribers := []*Subscriber{} - - c.mutex.Lock() - for _, subscriber := range c.Subscribers { - if subscriber != toDelete { - removed = true - } else { - newSubscribers = append(newSubscribers, subscriber) - } - } - - c.Subscribers = newSubscribers - c.mutex.Unlock() - - return removed -} - -func (c *Call) RemovePublisher(toDelete *Publisher) bool { - removed := false - newPublishers := []*Publisher{} - - c.mutex.Lock() - for _, publisher := range c.Publishers { - if publisher == toDelete { - removed = true - } else { - newPublishers = append(newPublishers, publisher) - } - } - - c.Publishers = newPublishers - c.mutex.Unlock() - - return removed -} diff --git a/src/conference.go b/src/conference.go deleted file mode 100644 index 110fd07..0000000 --- a/src/conference.go +++ /dev/null @@ -1,114 +0,0 @@ -/* -Copyright 2022 The Matrix.org Foundation C.I.C. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package main - -import ( - "errors" - "sync" - - "github.com/sirupsen/logrus" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" -) - -var ( - ErrNoSuchCall = errors.New("no such call") - ErrFoundExistingCallWithSameSessionAndDeviceID = errors.New("found existing call with equal DeviceID and SessionID") -) - -// Configuration for the group conferences (calls). -type CallConfig struct { - // Keep-alive timeout for WebRTC connections. If no keep-alive has been received - // from the client for this duration, the connection is considered dead. - KeepAliveTimeout int -} - -type Conference struct { - ConfID string - Calls map[string]*Call // By callID - Config *CallConfig // TODO: this must be protected by a mutex actually - - mutex sync.RWMutex - logger *logrus.Entry - Metadata *Metadata -} - -func NewConference(confID string, config *CallConfig) *Conference { - conference := new(Conference) - conference.Config = config - conference.ConfID = confID - conference.Calls = make(map[string]*Call) - conference.Metadata = NewMetadata(conference) - conference.logger = logrus.WithFields(logrus.Fields{ - "conf_id": confID, - }) - - return conference -} - -func (c *Conference) GetCall(callID string, create bool) (*Call, error) { - c.mutex.Lock() - defer c.mutex.Unlock() - call := c.Calls[callID] - - if call == nil { - if create { - call = NewCall(callID, c) - c.Calls[callID] = call - } else { - return nil, ErrNoSuchCall - } - } - - return call, nil -} - -func (c *Conference) RemoveOldCallsByDeviceAndSessionIDs(deviceID id.DeviceID, sessionID id.SessionID) error { - var err error - - for _, call := range c.Calls { - if call.DeviceID == deviceID { - if call.RemoteSessionID == sessionID { - err = ErrFoundExistingCallWithSameSessionAndDeviceID - } else { - call.Terminate() - } - } - } - - return err -} - -func (c *Conference) SendUpdatedMetadataFromPeer(callID string) { - for _, call := range c.Calls { - if call.CallID != callID { - call.SendDataChannelMessage(event.SFUMessage{Op: event.SFUOperationMetadata}) - } - } -} - -func (c *Conference) GetPublishers() []*Publisher { - publishers := []*Publisher{} - - c.mutex.RLock() - for _, call := range c.Calls { - publishers = append(publishers, call.Publishers...) - } - c.mutex.RUnlock() - - return publishers -} diff --git a/src/conference/conference.go b/src/conference/conference.go new file mode 100644 index 0000000..290e7f9 --- /dev/null +++ b/src/conference/conference.go @@ -0,0 +1,147 @@ +/* +Copyright 2022 The Matrix.org Foundation C.I.C. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package conference + +import ( + "github.com/matrix-org/waterfall/src/peer" + "github.com/pion/webrtc/v3" + "github.com/sirupsen/logrus" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +// Configuration for the group conferences (calls). +type CallConfig struct { + // Keep-alive timeout for WebRTC connections. If no keep-alive has been received + // from the client for this duration, the connection is considered dead. + KeepAliveTimeout int +} + +type Participant struct { + Peer *peer.Peer + Data *ParticipantData +} + +type ParticipantData struct { + RemoteSessionID id.SessionID + StreamMetadata event.CallSDPStreamMetadata +} + +type Conference struct { + conferenceID string + config *CallConfig + participants map[peer.ID]*Participant + participantsChannel peer.MessageChannel + logger *logrus.Entry +} + +func NewConference(confID string, config *CallConfig) *Conference { + conference := new(Conference) + conference.config = config + conference.conferenceID = confID + conference.participants = make(map[peer.ID]*Participant) + conference.participantsChannel = make(peer.MessageChannel) + conference.logger = logrus.WithFields(logrus.Fields{ + "conf_id": confID, + }) + return conference +} + +// New participant tries to join the conference. +func (c *Conference) OnNewParticipant(participantID peer.ID, inviteEvent *event.CallInviteEventContent) { + // As per MSC3401, when the `session_id` field changes from an incoming `m.call.member` event, + // any existing calls from this device in this call should be terminated. + // TODO: Implement this. + /* + for _, participant := range c.participants { + if participant.data.DeviceID == inviteEvent.DeviceID { + if participant.data.RemoteSessionID == inviteEvent.SenderSessionID { + c.logger.WithFields(logrus.Fields{ + "device_id": inviteEvent.DeviceID, + "session_id": inviteEvent.SenderSessionID, + }).Errorf("Found existing participant with equal DeviceID and SessionID") + return + } else { + participant.Terminate() + delete(c.participants, participant.data.UserID) + } + } + } + */ + + peer, _, err := peer.NewPeer(participantID, c.conferenceID, inviteEvent.Offer.SDP, c.participantsChannel) + if err != nil { + c.logger.WithError(err).Errorf("Failed to create new peer") + return + } + + participantData := &ParticipantData{ + RemoteSessionID: inviteEvent.SenderSessionID, + StreamMetadata: inviteEvent.SDPStreamMetadata, + } + + c.participants[participantID] = &Participant{Peer: peer, Data: participantData} + + // TODO: Send the SDP answer back to the participant's device. +} + +func (c *Conference) OnCandidates(peerID peer.ID, candidatesEvent *event.CallCandidatesEventContent) { + if participant := c.getParticipant(peerID); participant != nil { + // Convert the candidates to the WebRTC format. + candidates := make([]webrtc.ICECandidateInit, len(candidatesEvent.Candidates)) + for i, candidate := range candidatesEvent.Candidates { + SDPMLineIndex := uint16(candidate.SDPMLineIndex) + candidates[i] = webrtc.ICECandidateInit{ + Candidate: candidate.Candidate, + SDPMid: &candidate.SDPMID, + SDPMLineIndex: &SDPMLineIndex, + } + } + + participant.Peer.AddICECandidates(candidates) + } +} + +func (c *Conference) OnSelectAnswer(peerID peer.ID, selectAnswerEvent *event.CallSelectAnswerEventContent) { + if participant := c.getParticipant(peerID); participant != nil { + if selectAnswerEvent.SelectedPartyID != peerID.DeviceID.String() { + c.logger.WithFields(logrus.Fields{ + "device_id": selectAnswerEvent.SelectedPartyID, + }).Errorf("Call was answered on a different device, kicking this peer") + participant.Peer.Terminate() + } + } +} + +func (c *Conference) OnHangup(peerID peer.ID, hangupEvent *event.CallHangupEventContent) { + if participant := c.getParticipant(peerID); participant != nil { + participant.Peer.Terminate() + } +} + +func (c *Conference) getParticipant(peerID peer.ID) *Participant { + participant, ok := c.participants[peerID] + if !ok { + c.logger.WithFields(logrus.Fields{ + "user_id": peerID.UserID, + "device_id": peerID.DeviceID, + }).Errorf("Failed to find participant") + return nil + } + + return participant +} diff --git a/src/matrix.go b/src/matrix.go index 0e67a65..06769f1 100644 --- a/src/matrix.go +++ b/src/matrix.go @@ -17,6 +17,7 @@ limitations under the License. package main import ( + "github.com/matrix-org/waterfall/src/conference" "github.com/sirupsen/logrus" "maunium.net/go/mautrix" ) @@ -45,7 +46,7 @@ func RunServer(config *Config) { focus := NewSFU( client, - &CallConfig{KeepAliveTimeout: config.KeepAliveTimeout}, + &conference.CallConfig{KeepAliveTimeout: config.KeepAliveTimeout}, ) syncer, ok := client.Syncer.(*mautrix.DefaultSyncer) diff --git a/src/metadata.go b/src/metadata.go deleted file mode 100644 index 3ae0ed6..0000000 --- a/src/metadata.go +++ /dev/null @@ -1,104 +0,0 @@ -/* -Copyright 2022 The Matrix.org Foundation C.I.C. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package main - -import ( - "sync" - - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" -) - -type Metadata struct { - mutex sync.RWMutex - data event.CallSDPStreamMetadata - conference *Conference -} - -func NewMetadata(conference *Conference) *Metadata { - metadata := new(Metadata) - - metadata.data = make(event.CallSDPStreamMetadata) - metadata.conference = conference - - return metadata -} - -func (m *Metadata) Update(deviceID id.DeviceID, metadata event.CallSDPStreamMetadata) { - m.mutex.Lock() - defer m.mutex.Unlock() - - // Update existing and add new - for streamID, info := range metadata { - m.data[streamID] = info - } - // Remove removed - for streamID, info := range m.data { - _, exists := metadata[streamID] - if info.DeviceID == deviceID && !exists { - delete(m.data, streamID) - } - } -} - -func (m *Metadata) RemoveByDevice(deviceID id.DeviceID) { - m.mutex.Lock() - defer m.mutex.Unlock() - - for streamID, info := range m.data { - if info.DeviceID == deviceID { - delete(m.data, streamID) - } - } -} - -// Get metadata to send to deviceID. This will not include the device's own -// metadata and metadata which includes tracks which we have not received yet. -func (m *Metadata) GetForDevice(deviceID id.DeviceID) event.CallSDPStreamMetadata { - metadata := make(event.CallSDPStreamMetadata) - - m.mutex.RLock() - defer m.mutex.RUnlock() - - for _, publisher := range m.conference.GetPublishers() { - if deviceID == publisher.Call.DeviceID { - continue - } - - streamID := publisher.Track.StreamID() - trackID := publisher.Track.ID() - - info, exists := metadata[streamID] - if exists { - info.Tracks[publisher.Track.ID()] = event.CallSDPStreamMetadataTrack{} - metadata[streamID] = info - } else { - metadata[streamID] = event.CallSDPStreamMetadataObject{ - UserID: publisher.Call.UserID, - DeviceID: publisher.Call.DeviceID, - Purpose: m.data[streamID].Purpose, - AudioMuted: m.data[streamID].AudioMuted, - VideoMuted: m.data[streamID].VideoMuted, - Tracks: event.CallSDPStreamMetadataTracks{ - trackID: {}, - }, - } - } - } - - return metadata -} diff --git a/src/peer/channel.go b/src/peer/channel.go new file mode 100644 index 0000000..0f5e127 --- /dev/null +++ b/src/peer/channel.go @@ -0,0 +1,57 @@ +package peer + +import ( + "github.com/pion/webrtc/v3" +) + +type MessageChannel chan interface{} + +type PeerJoinedTheCall struct { + Sender ID +} + +type PeerLeftTheCall struct { + Sender ID +} + +type NewTrackPublished struct { + Sender ID + Track *webrtc.TrackLocalStaticRTP +} + +type PublishedTrackFailed struct { + Sender ID + Track *webrtc.TrackLocalStaticRTP +} + +type NewICECandidate struct { + Sender ID + Candidate *webrtc.ICECandidate +} + +type ICEGatheringComplete struct { + Sender ID +} + +type NewOffer struct { + Sender ID + Offer *webrtc.SessionDescription +} + +type DataChannelOpened struct { + Sender ID +} + +type DataChannelClosed struct { + Sender ID +} + +type DataChannelMessage struct { + Sender ID + Message string +} + +type DataChannelError struct { + Sender ID + Err error +} diff --git a/src/peer/id.go b/src/peer/id.go new file mode 100644 index 0000000..e7b4697 --- /dev/null +++ b/src/peer/id.go @@ -0,0 +1,8 @@ +package peer + +import "maunium.net/go/mautrix/id" + +type ID struct { + UserID id.UserID + DeviceID id.DeviceID +} diff --git a/src/peer/peer.go b/src/peer/peer.go new file mode 100644 index 0000000..e9a3b8b --- /dev/null +++ b/src/peer/peer.go @@ -0,0 +1,141 @@ +package peer + +import ( + "errors" + "sync" + + "github.com/pion/webrtc/v3" + "github.com/sirupsen/logrus" +) + +var ( + ErrCantCreatePeerConnection = errors.New("can't create peer connection") + ErrCantSetRemoteDecsription = errors.New("can't set remote description") + ErrCantCreateAnswer = errors.New("can't create answer") + ErrCantSetLocalDescription = errors.New("can't set local description") + ErrCantCreateLocalDescription = errors.New("can't create local description") + ErrDataChannelNotAvailable = errors.New("data channel is not available") + ErrCantSubscribeToTrack = errors.New("can't subscribe to track") +) + +type Peer struct { + id ID + logger *logrus.Entry + notify chan<- interface{} + peerConnection *webrtc.PeerConnection + + dataChannelMutex sync.Mutex + dataChannel *webrtc.DataChannel +} + +func NewPeer( + info ID, + conferenceId string, + sdpOffer string, + notify chan<- interface{}, +) (*Peer, *webrtc.SessionDescription, error) { + logger := logrus.WithFields(logrus.Fields{ + "user_id": info.UserID, + "device_id": info.DeviceID, + "conf_id": conferenceId, + }) + + peerConnection, err := webrtc.NewPeerConnection(webrtc.Configuration{}) + if err != nil { + logger.WithError(err).Error("failed to create peer connection") + return nil, nil, ErrCantCreatePeerConnection + } + + peer := &Peer{ + id: info, + logger: logger, + notify: notify, + peerConnection: peerConnection, + } + + peerConnection.OnTrack(peer.onRtpTrackReceived) + peerConnection.OnDataChannel(peer.onDataChannelReady) + peerConnection.OnICECandidate(peer.onICECandidateGathered) + peerConnection.OnNegotiationNeeded(peer.onNegotiationNeeded) + peerConnection.OnICEConnectionStateChange(peer.onICEConnectionStateChanged) + peerConnection.OnICEGatheringStateChange(peer.onICEGatheringStateChanged) + peerConnection.OnConnectionStateChange(peer.onConnectionStateChanged) + peerConnection.OnSignalingStateChange(peer.onSignalingStateChanged) + + err = peerConnection.SetRemoteDescription(webrtc.SessionDescription{ + Type: webrtc.SDPTypeOffer, + SDP: sdpOffer, + }) + if err != nil { + logger.WithError(err).Error("failed to set remote description") + peerConnection.Close() + return nil, nil, ErrCantSetRemoteDecsription + } + + answer, err := peerConnection.CreateAnswer(nil) + if err != nil { + logger.WithError(err).Error("failed to create answer") + peerConnection.Close() + return nil, nil, ErrCantCreateAnswer + } + + if err := peerConnection.SetLocalDescription(answer); err != nil { + logger.WithError(err).Error("failed to set local description") + peerConnection.Close() + return nil, nil, ErrCantSetLocalDescription + } + + // TODO: Do we really need to call `webrtc.GatheringCompletePromise` + // as in the previous version of the `waterfall` here? + + sdpAnswer := peerConnection.LocalDescription() + if sdpAnswer == nil { + logger.WithError(err).Error("could not generate a local description") + peerConnection.Close() + return nil, nil, ErrCantCreateLocalDescription + } + + return peer, sdpAnswer, nil +} + +func (p *Peer) Terminate() { + if err := p.peerConnection.Close(); err != nil { + p.logger.WithError(err).Error("failed to close peer connection") + } + + p.notify <- PeerLeftTheCall{Sender: p.id} +} + +func (p *Peer) AddICECandidates(candidates []webrtc.ICECandidateInit) { + for _, candidate := range candidates { + if err := p.peerConnection.AddICECandidate(candidate); err != nil { + p.logger.WithError(err).Error("failed to add ICE candidate") + } + } +} + +func (p *Peer) SubscribeToTrack(track *webrtc.TrackLocalStaticRTP) error { + _, err := p.peerConnection.AddTrack(track) + if err != nil { + p.logger.WithError(err).Error("failed to add track") + return ErrCantSubscribeToTrack + } + + return nil +} + +func (p *Peer) SendOverDataChannel(json string) error { + p.dataChannelMutex.Lock() + defer p.dataChannelMutex.Unlock() + + if p.dataChannel == nil { + p.logger.Error("can't send data over data channel: data channel is not ready") + return ErrDataChannelNotAvailable + } + + if err := p.dataChannel.SendText(json); err != nil { + p.logger.WithError(err).Error("failed to send data over data channel") + } + + return nil +} diff --git a/src/peer/webrtc.go b/src/peer/webrtc.go new file mode 100644 index 0000000..4d97a80 --- /dev/null +++ b/src/peer/webrtc.go @@ -0,0 +1,178 @@ +package peer + +import ( + "errors" + "io" + "time" + + "github.com/pion/rtcp" + "github.com/pion/webrtc/v3" +) + +// A callback that is called once we receive first RTP packets from a track, i.e. +// we call this function each time a new track is received. +func (p *Peer) onRtpTrackReceived(remoteTrack *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) { + // Send a PLI on an interval so that the publisher is pushing a keyframe every rtcpPLIInterval. + // This can be less wasteful by processing incoming RTCP events, then we would emit a NACK/PLI + // when a viewer requests it. + // + // TODO: Add RTCP handling based on the PR from @SimonBrandner. + go func() { + ticker := time.NewTicker(time.Millisecond * 500) // every 500ms + for range ticker.C { + rtcp := []rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: uint32(remoteTrack.SSRC())}} + if rtcpSendErr := p.peerConnection.WriteRTCP(rtcp); rtcpSendErr != nil { + p.logger.Errorf("Failed to send RTCP PLI: %v", rtcpSendErr) + } + } + }() + + // Create a local track, all our SFU clients that are subscribed to this + // peer (publisher) wil be fed via this track. + localTrack, err := webrtc.NewTrackLocalStaticRTP( + remoteTrack.Codec().RTPCodecCapability, + remoteTrack.ID(), + remoteTrack.StreamID(), + ) + if err != nil { + p.logger.WithError(err).Error("failed to create local track") + return + } + + // Notify others that our track has just been published. + p.notify <- NewTrackPublished{Sender: p.id, Track: localTrack} + + // Start forwarding the data from the remote track to the local track, + // so that everyone who is subscribed to this track will receive the data. + go func() { + rtpBuf := make([]byte, 1400) + + for { + index, _, readErr := remoteTrack.Read(rtpBuf) + // TODO: inform the conference that this publisher's track is not available anymore. + if readErr != nil { + if readErr == io.EOF { // finished, no more data, no error, inform others + p.logger.Info("remote track closed") + } else { // finished, no more data, but with error, inform others + p.logger.WithError(readErr).Error("failed to read from remote track") + } + p.notify <- PublishedTrackFailed{Sender: p.id, Track: localTrack} + } + + // ErrClosedPipe means we don't have any subscribers, this is ok if no peers have connected yet. + if _, err = localTrack.Write(rtpBuf[:index]); err != nil && !errors.Is(err, io.ErrClosedPipe) { + p.logger.WithError(err).Error("failed to write to local track") + p.notify <- PublishedTrackFailed{Sender: p.id, Track: localTrack} + } + } + }() +} + +// A callback that is called once we receive an ICE candidate for this peer connection. +func (p *Peer) onICECandidateGathered(candidate *webrtc.ICECandidate) { + if candidate == nil { + p.logger.Info("ICE candidate gathering finished") + return + } + + p.logger.WithField("candidate", candidate).Debug("ICE candidate gathered") + p.notify <- NewICECandidate{Sender: p.id, Candidate: candidate} +} + +// A callback that is called once we receive an ICE connection state change for this peer connection. +func (p *Peer) onNegotiationNeeded() { + p.logger.Debug("negotiation needed") + offer, err := p.peerConnection.CreateOffer(nil) + if err != nil { + p.logger.WithError(err).Error("failed to create offer") + return + } + + if err := p.peerConnection.SetLocalDescription(offer); err != nil { + p.logger.WithError(err).Error("failed to set local description") + return + } + + p.notify <- NewOffer{Sender: p.id, Offer: &offer} +} + +// A callback that is called once we receive an ICE connection state change for this peer connection. +func (p *Peer) onICEConnectionStateChanged(state webrtc.ICEConnectionState) { + p.logger.WithField("state", state).Debug("ICE connection state changed") + + switch state { + case webrtc.ICEConnectionStateFailed, webrtc.ICEConnectionStateDisconnected: + // TODO: We may want to treat it as an opportunity for the ICE restart instead. + // TODO: Ask Simon if we should do it here as in the previous implementation of the + // `waterfall` or the way I did it in this new implementation. + // p.notify <- PeerLeftTheCall{sender: p.data} + case webrtc.ICEConnectionStateCompleted, webrtc.ICEConnectionStateConnected: + // TODO: Start keep-alive timer over the data channel to check the connecitons that hanged. + // TODO: Ask Simon if we should do it here as in the previous implementation of the + // `waterfall` or the way I did it in this new implementation. + // p.notify <- PeerJoinedTheCall{sender: p.data} + p.notify <- ICEGatheringComplete{Sender: p.id} + } +} + +func (p *Peer) onICEGatheringStateChanged(state webrtc.ICEGathererState) { + p.logger.WithField("state", state).Debug("ICE gathering state changed") + + if state == webrtc.ICEGathererStateComplete { + p.notify <- ICEGatheringComplete{Sender: p.id} + } +} + +func (p *Peer) onSignalingStateChanged(state webrtc.SignalingState) { + p.logger.WithField("state", state).Debug("signaling state changed") +} + +func (p *Peer) onConnectionStateChanged(state webrtc.PeerConnectionState) { + p.logger.WithField("state", state).Debug("connection state changed") + + switch state { + case webrtc.PeerConnectionStateFailed, webrtc.PeerConnectionStateDisconnected, webrtc.PeerConnectionStateClosed: + p.notify <- PeerLeftTheCall{Sender: p.id} + case webrtc.PeerConnectionStateConnected: + p.notify <- PeerJoinedTheCall{Sender: p.id} + } +} + +// A callback that is called once the data channel is ready to be used. +func (p *Peer) onDataChannelReady(dc *webrtc.DataChannel) { + p.dataChannelMutex.Lock() + defer p.dataChannelMutex.Unlock() + + if p.dataChannel != nil { + p.logger.Error("data channel already exists") + p.dataChannel.Close() + return + } + + p.dataChannel = dc + p.logger.WithField("label", dc.Label()).Info("data channel ready") + + dc.OnOpen(func() { + p.logger.Info("data channel opened") + p.notify <- DataChannelOpened{Sender: p.id} + }) + + dc.OnClose(func() { + p.logger.Info("data channel closed") + p.notify <- DataChannelClosed{Sender: p.id} + }) + + dc.OnMessage(func(msg webrtc.DataChannelMessage) { + p.logger.WithField("message", msg).Debug("data channel message received") + if msg.IsString { + p.notify <- DataChannelMessage{Sender: p.id, Message: string(msg.Data)} + } else { + p.logger.Warn("data channel message is not a string, ignoring") + } + }) + + dc.OnError(func(err error) { + p.logger.WithError(err).Error("data channel error") + p.notify <- DataChannelError{Sender: p.id, Err: err} + }) +} diff --git a/src/sfu.go b/src/sfu.go index c183b23..08c6811 100644 --- a/src/sfu.go +++ b/src/sfu.go @@ -19,9 +19,12 @@ package main import ( "errors" + "github.com/matrix-org/waterfall/src/conference" + "github.com/matrix-org/waterfall/src/peer" "github.com/sirupsen/logrus" "maunium.net/go/mautrix" "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" ) var ErrNoSuchConference = errors.New("no such conference") @@ -34,140 +37,193 @@ type SFU struct { // Matrix client. client *mautrix.Client // All calls currently forwarded by this SFU. - conferences map[string]*Conference - // Structured logging for the SFU. - logger *logrus.Entry + conferences map[string]*conference.Conference // Configuration for the calls. - config *CallConfig + config *conference.CallConfig } // Creates a new instance of the SFU with the given configuration. -func NewSFU(client *mautrix.Client, config *CallConfig) *SFU { +func NewSFU(client *mautrix.Client, config *conference.CallConfig) *SFU { return &SFU{ client: client, - conferences: make(map[string]*Conference), - logger: logrus.WithField("module", "sfu"), + conferences: make(map[string]*conference.Conference), config: config, } } -// Returns a conference by its `id`, or creates a new one if it doesn't exist yet. -func (f *SFU) GetOrCreateConference(confID string, create bool) (*Conference, error) { - if conference := f.conferences[confID]; conference != nil { - return conference, nil - } - - if create { - f.logger.Printf("creating new conference %s", confID) - conference := NewConference(confID, f.config) - f.conferences[confID] = conference - return conference, nil - } - - return nil, ErrNoSuchConference -} - -func (f *SFU) GetCall(confID string, callID string) (*Call, error) { - var ( - conf *Conference - call *Call - err error - ) - - if conf, err = f.GetOrCreateConference(confID, false); err != nil || conf == nil { - f.logger.Printf("failed to get conf %s: %s", confID, err) - return nil, err - } - - if call, err = conf.GetCall(callID, false); err != nil || call == nil { - f.logger.Printf("failed to get call %s: %s", callID, err) - return nil, err - } - - return call, nil -} - // Handles To-Device events that the SFU receives from clients. +// +//nolint:funlen func (f *SFU) onMatrixEvent(_ mautrix.EventSource, evt *event.Event) { // We only care about to-device events. if evt.Type.Class != event.ToDeviceEventType { - f.logger.Warn("ignoring a not to-device event") + logrus.Warn("ignoring a not to-device event") return } - evtLogger := f.logger.WithFields(logrus.Fields{ + // TODO: Don't create logger again and again, it might be a bit expensive. + logger := logrus.WithFields(logrus.Fields{ "type": evt.Type.Type, "user_id": evt.Sender.String(), "conf_id": evt.Content.Raw["conf_id"], }) if evt.Content.Raw["dest_session_id"] != LocalSessionID { - evtLogger.WithField("dest_session_id", LocalSessionID).Warn("SessionID does not match our SessionID - ignoring") + logger.WithField("dest_session_id", LocalSessionID).Warn("SessionID does not match our SessionID - ignoring") return } - var ( - conference *Conference - call *Call - err error - ) - switch evt.Type.Type { + // Someone tries to participate in a call (join a call). case event.ToDeviceCallInvite.Type: invite := evt.Content.AsCallInvite() if invite == nil { - evtLogger.Error("failed to parse invite") + logger.Error("failed to parse invite") return } - if conference, err = f.GetOrCreateConference(invite.ConfID, true); err != nil { - evtLogger.WithError(err).WithFields(logrus.Fields{ - "conf_id": invite.ConfID, - }).Error("failed to create conf") + // If there is an invitation sent and the conf does not exist, create one. + if conf := f.conferences[invite.ConfID]; conf == nil { + logger.Infof("creating new conference %s", invite.ConfID) + f.conferences[invite.ConfID] = conference.NewConference(invite.ConfID, f.config) + } - return + peerID := peer.ID{ + UserID: evt.Sender, + DeviceID: invite.DeviceID, } - if err := conference.RemoveOldCallsByDeviceAndSessionIDs(invite.DeviceID, invite.SenderSessionID); err != nil { - evtLogger.WithError(err).Error("error removing old calls - ignoring call") + // Inform conference about incoming participant. + f.conferences[invite.ConfID].OnNewParticipant(peerID, invite) + + // Someone tries to send ICE candidates to the existing call. + case event.ToDeviceCallCandidates.Type: + candidates := evt.Content.AsCallCandidates() + if candidates == nil { + logger.Error("failed to parse candidates") return } - if call, err = conference.GetCall(invite.CallID, true); err != nil || call == nil { - evtLogger.WithError(err).Error("failed to create call") + conference := f.conferences[candidates.ConfID] + if conference == nil { + logger.Errorf("received candidates for unknown conference %s", candidates.ConfID) return } - call.InitWithInvite(evt, f.client) - call.OnInvite(invite) - case event.ToDeviceCallCandidates.Type: - candidates := evt.Content.AsCallCandidates() - if call, err = f.GetCall(candidates.ConfID, candidates.CallID); err != nil { - return + peerID := peer.ID{ + UserID: evt.Sender, + DeviceID: candidates.DeviceID, } - call.OnCandidates(candidates) + conference.OnCandidates(peerID, candidates) + + // Someone informs us about them accepting our (SFU's) SDP answer for an existing call. case event.ToDeviceCallSelectAnswer.Type: selectAnswer := evt.Content.AsCallSelectAnswer() - if call, err = f.GetCall(selectAnswer.ConfID, selectAnswer.CallID); err != nil { + if selectAnswer == nil { + logger.Error("failed to parse select_answer") + return + } + + conference := f.conferences[selectAnswer.ConfID] + if conference == nil { + logger.Errorf("received select_answer for unknown conference %s", selectAnswer.ConfID) return } - call.OnSelectAnswer(selectAnswer) + peerID := peer.ID{ + UserID: evt.Sender, + DeviceID: selectAnswer.DeviceID, + } + + conference.OnSelectAnswer(peerID, selectAnswer) + + // Someone tries to inform us about leaving an existing call. case event.ToDeviceCallHangup.Type: hangup := evt.Content.AsCallHangup() - if call, err = f.GetCall(hangup.ConfID, hangup.CallID); err != nil { + if hangup == nil { + logger.Error("failed to parse hangup") + return + } + + conference := f.conferences[hangup.ConfID] + if conference == nil { + logger.Errorf("received hangup for unknown conference %s", hangup.ConfID) return } - call.OnHangup() + peerID := peer.ID{ + UserID: evt.Sender, + DeviceID: hangup.DeviceID, + } + + conference.OnHangup(peerID, hangup) + // Events that we **should not** receive! case event.ToDeviceCallNegotiate.Type: - evtLogger.Warn("ignoring event as it should be handled over DC") + logger.Warn("ignoring negotiate event that must be handled over the data channel") case event.ToDeviceCallReject.Type: + logger.Warn("ignoring reject event that must be handled over the data channel") case event.ToDeviceCallAnswer.Type: - evtLogger.Warn("ignoring event as we are always the ones answering") + logger.Warn("ignoring event as we are always the ones sending the SDP answer at the moment") default: - evtLogger.Warnf("ignoring unexpected event: %s", evt.Type.Type) + logger.Warnf("ignoring unexpected event: %s", evt.Type.Type) + } +} + +func (f *SFU) createSDPAnswerEvent( + conferenceID string, + destSessionID id.SessionID, + peerID peer.ID, + sdp string, + streamMetadata event.CallSDPStreamMetadata, +) *event.Content { + return &event.Content{ + Parsed: event.CallAnswerEventContent{ + BaseCallEventContent: createBaseEventContent(conferenceID, f.client.DeviceID, peerID.DeviceID, destSessionID), + Answer: event.CallData{ + Type: "answer", + SDP: sdp, + }, + SDPStreamMetadata: streamMetadata, + }, + } +} + +func createBaseEventContent( + conferenceID string, + sfuDeviceID id.DeviceID, + destDeviceID id.DeviceID, + destSessionID id.SessionID, +) event.BaseCallEventContent { + return event.BaseCallEventContent{ + CallID: conferenceID, + ConfID: conferenceID, + DeviceID: sfuDeviceID, + SenderSessionID: LocalSessionID, + DestSessionID: destSessionID, + PartyID: string(destDeviceID), + Version: event.CallVersion("1"), + } +} + +// Sends a to-device event to the given user. +func (f *SFU) sendToDevice(participantID peer.ID, ev *event.Event) { + // TODO: Don't create logger again and again, it might be a bit expensive. + logger := logrus.WithFields(logrus.Fields{ + "user_id": participantID.UserID, + "device_id": participantID.DeviceID, + }) + + sendRequest := &mautrix.ReqSendToDevice{ + Messages: map[id.UserID]map[id.DeviceID]*event.Content{ + participantID.UserID: { + participantID.DeviceID: &ev.Content, + }, + }, + } + + if _, err := f.client.SendToDevice(ev.Type, sendRequest); err != nil { + logger.Errorf("failed to send to-device event: %w", err) } } From 85aa12c55a4201a8e9461c12bef3c59ea9841874 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Tue, 22 Nov 2022 00:12:32 +0100 Subject: [PATCH 04/62] refactor: create a better structure for the sources --- src/{ => config}/config.go | 18 +++++++++--------- src/main.go | 6 ++++-- src/{ => signaling}/matrix.go | 7 ++++--- src/{sfu.go => signaling/signaling.go} | 22 +++++++++++----------- 4 files changed, 28 insertions(+), 25 deletions(-) rename src/{ => config}/config.go (82%) rename src/{ => signaling}/matrix.go (93%) rename src/{sfu.go => signaling/signaling.go} (90%) diff --git a/src/config.go b/src/config/config.go similarity index 82% rename from src/config.go rename to src/config/config.go index c76e093..691ace9 100644 --- a/src/config.go +++ b/src/config/config.go @@ -1,4 +1,4 @@ -package main +package config import ( "errors" @@ -27,14 +27,14 @@ type Config struct { // If the environment variable is not set, tries to load a config from the // provided path to the config file (YAML). Returns an error if the config could // not be loaded. -func loadConfig(path string) (*Config, error) { - config, err := loadConfigFromEnv() +func LoadConfig(path string) (*Config, error) { + config, err := LoadConfigFromEnv() if err != nil { if !errors.Is(err, ErrNoConfigEnvVar) { return nil, err } - return loadConfigFromPath(path) + return LoadConfigFromPath(path) } return config, nil @@ -45,17 +45,17 @@ var ErrNoConfigEnvVar = errors.New("environment variable not set or invalid") // Tries to load the config from environment variable (`CONFIG`). // Returns an error if not all environment variables are set. -func loadConfigFromEnv() (*Config, error) { +func LoadConfigFromEnv() (*Config, error) { configEnv := os.Getenv("CONFIG") if configEnv == "" { return nil, ErrNoConfigEnvVar } - return loadConfigFromString(configEnv) + return LoadConfigFromString(configEnv) } // Tries to load a config from the provided path. -func loadConfigFromPath(path string) (*Config, error) { +func LoadConfigFromPath(path string) (*Config, error) { logrus.WithField("path", path).Info("loading config") file, err := os.ReadFile(path) @@ -63,12 +63,12 @@ func loadConfigFromPath(path string) (*Config, error) { return nil, fmt.Errorf("failed to read file: %w", err) } - return loadConfigFromString(string(file)) + return LoadConfigFromString(string(file)) } // Load config from the provided string. // Returns an error if the string is not a valid YAML. -func loadConfigFromString(configString string) (*Config, error) { +func LoadConfigFromString(configString string) (*Config, error) { logrus.Info("loading config from string") var config Config diff --git a/src/main.go b/src/main.go index ec4bfea..320e6a0 100644 --- a/src/main.go +++ b/src/main.go @@ -22,6 +22,8 @@ import ( "os/signal" "syscall" + "github.com/matrix-org/waterfall/src/config" + "github.com/matrix-org/waterfall/src/signaling" "github.com/sirupsen/logrus" ) @@ -60,10 +62,10 @@ func main() { }() // Load the config file from the environment variable or path. - config, err := loadConfig(*configFilePath) + config, err := config.LoadConfig(*configFilePath) if err != nil { logrus.WithError(err).Fatal("could not load config") } - RunServer(config) + signaling.RunServer(config) } diff --git a/src/matrix.go b/src/signaling/matrix.go similarity index 93% rename from src/matrix.go rename to src/signaling/matrix.go index 06769f1..d2d43a2 100644 --- a/src/matrix.go +++ b/src/signaling/matrix.go @@ -14,10 +14,11 @@ See the License for the specific language governing permissions and limitations under the License. */ -package main +package signaling import ( "github.com/matrix-org/waterfall/src/conference" + "github.com/matrix-org/waterfall/src/config" "github.com/sirupsen/logrus" "maunium.net/go/mautrix" ) @@ -26,7 +27,7 @@ const LocalSessionID = "sfu" // Starts the Matrix client and connects to the homeserver, // runs the SFU. Returns only when the sync with Matrix fails. -func RunServer(config *Config) { +func RunServer(config *config.Config) { client, err := mautrix.NewClient(config.HomeserverURL, config.UserID, config.AccessToken) if err != nil { logrus.WithError(err).Fatal("Failed to create client") @@ -44,7 +45,7 @@ func RunServer(config *Config) { logrus.WithField("device_id", whoami.DeviceID).Info("Identified SFU as DeviceID") client.DeviceID = whoami.DeviceID - focus := NewSFU( + focus := NewSignalingServer( client, &conference.CallConfig{KeepAliveTimeout: config.KeepAliveTimeout}, ) diff --git a/src/sfu.go b/src/signaling/signaling.go similarity index 90% rename from src/sfu.go rename to src/signaling/signaling.go index 08c6811..ab9ea6e 100644 --- a/src/sfu.go +++ b/src/signaling/signaling.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package main +package signaling import ( "errors" @@ -29,11 +29,11 @@ import ( var ErrNoSuchConference = errors.New("no such conference") -// The top-level state of the SFU. -// Note that in Matrix MSCs, the term "focus" is used to refer to the SFU. But since "focus" is a very -// generic name and only makes sense in a certain context, we use the term "SFU" instead to avoid confusion -// given that this particular part is just the SFU logic (and not the "focus" selection algorithm etc). -type SFU struct { +// The top-level state of the SignalingServer. +// Note that in Matrix MSCs, the term "focus" is used to refer to the SignalingServer. But since "focus" is a very +// generic name and only makes sense in a certain context, we use the term "SignalingServer" instead to avoid confusion +// given that this particular part is just the SignalingServer logic (and not the "focus" selection algorithm etc). +type SignalingServer struct { // Matrix client. client *mautrix.Client // All calls currently forwarded by this SFU. @@ -43,8 +43,8 @@ type SFU struct { } // Creates a new instance of the SFU with the given configuration. -func NewSFU(client *mautrix.Client, config *conference.CallConfig) *SFU { - return &SFU{ +func NewSignalingServer(client *mautrix.Client, config *conference.CallConfig) *SignalingServer { + return &SignalingServer{ client: client, conferences: make(map[string]*conference.Conference), config: config, @@ -54,7 +54,7 @@ func NewSFU(client *mautrix.Client, config *conference.CallConfig) *SFU { // Handles To-Device events that the SFU receives from clients. // //nolint:funlen -func (f *SFU) onMatrixEvent(_ mautrix.EventSource, evt *event.Event) { +func (f *SignalingServer) onMatrixEvent(_ mautrix.EventSource, evt *event.Event) { // We only care about to-device events. if evt.Type.Class != event.ToDeviceEventType { logrus.Warn("ignoring a not to-device event") @@ -171,7 +171,7 @@ func (f *SFU) onMatrixEvent(_ mautrix.EventSource, evt *event.Event) { } } -func (f *SFU) createSDPAnswerEvent( +func (f *SignalingServer) createSDPAnswerEvent( conferenceID string, destSessionID id.SessionID, peerID peer.ID, @@ -208,7 +208,7 @@ func createBaseEventContent( } // Sends a to-device event to the given user. -func (f *SFU) sendToDevice(participantID peer.ID, ev *event.Event) { +func (f *SignalingServer) sendToDevice(participantID peer.ID, ev *event.Event) { // TODO: Don't create logger again and again, it might be a bit expensive. logger := logrus.WithFields(logrus.Fields{ "user_id": participantID.UserID, From ba6eb3f73714acad4a9937200540e5f2ee69b45c Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Tue, 22 Nov 2022 17:53:08 +0100 Subject: [PATCH 05/62] refactor: finalize signaling and peer communication --- config.yaml.sample | 10 +- docker-compose.yaml | 10 +- src/conference/conference.go | 118 ++++++-------- src/conference/config.go | 8 + src/conference/messages.go | 179 ++++++++++++++++++++++ src/conference/participant.go | 91 +++++++++++ src/config/config.go | 25 +-- src/main.go | 10 +- src/peer/channel.go | 19 +-- src/peer/peer.go | 21 ++- src/peer/webrtc.go | 18 +-- src/{signaling/signaling.go => router.go} | 116 +++----------- src/signaling/config.go | 13 ++ src/signaling/matrix.go | 156 +++++++++++++++++-- 14 files changed, 572 insertions(+), 222 deletions(-) create mode 100644 src/conference/config.go create mode 100644 src/conference/messages.go create mode 100644 src/conference/participant.go rename src/{signaling/signaling.go => router.go} (54%) create mode 100644 src/signaling/config.go diff --git a/config.yaml.sample b/config.yaml.sample index 0b1b4df..e53b2e1 100644 --- a/config.yaml.sample +++ b/config.yaml.sample @@ -1,4 +1,6 @@ -homeserverurl: "http://localhost:8008" -userid: "@sfu:shadowfax" -accesstoken: "..." -timeout: 30 +matrix: + homeserverurl: "http://localhost:8008" + userid: "@sfu:shadowfax" + accesstoken: "..." +conference: + timeout: 30 diff --git a/docker-compose.yaml b/docker-compose.yaml index 90ffc6d..83d75e0 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -9,7 +9,9 @@ services: environment: # Set the `CONFIG` to the configuration you want. CONFIG: | - homeserverurl: "http://localhost:8008" - userid: "@sfu:shadowfax" - accesstoken: "..." - timeout: 30 + matrix: + homeserverurl: "http://localhost:8008" + userid: "@sfu:shadowfax" + accesstoken: "..." + conference: + timeout: 30 diff --git a/src/conference/conference.go b/src/conference/conference.go index 290e7f9..1a2a243 100644 --- a/src/conference/conference.go +++ b/src/conference/conference.go @@ -18,46 +18,33 @@ package conference import ( "github.com/matrix-org/waterfall/src/peer" + "github.com/matrix-org/waterfall/src/signaling" "github.com/pion/webrtc/v3" "github.com/sirupsen/logrus" "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" ) -// Configuration for the group conferences (calls). -type CallConfig struct { - // Keep-alive timeout for WebRTC connections. If no keep-alive has been received - // from the client for this duration, the connection is considered dead. - KeepAliveTimeout int -} - -type Participant struct { - Peer *peer.Peer - Data *ParticipantData -} - -type ParticipantData struct { - RemoteSessionID id.SessionID - StreamMetadata event.CallSDPStreamMetadata -} - type Conference struct { - conferenceID string - config *CallConfig - participants map[peer.ID]*Participant - participantsChannel peer.MessageChannel - logger *logrus.Entry + id string + config Config + signaling signaling.MatrixSignaling + participants map[peer.ID]*Participant + peerEventsStream chan peer.Message + logger *logrus.Entry } -func NewConference(confID string, config *CallConfig) *Conference { - conference := new(Conference) - conference.config = config - conference.conferenceID = confID - conference.participants = make(map[peer.ID]*Participant) - conference.participantsChannel = make(peer.MessageChannel) - conference.logger = logrus.WithFields(logrus.Fields{ - "conf_id": confID, - }) +func NewConference(confID string, config Config, signaling signaling.MatrixSignaling) *Conference { + conference := &Conference{ + id: confID, + config: config, + signaling: signaling, + participants: make(map[peer.ID]*Participant), + peerEventsStream: make(chan peer.Message), + logger: logrus.WithFields(logrus.Fields{"conf_id": confID}), + } + + // Start conference "main loop". + go conference.processMessages() return conference } @@ -66,41 +53,43 @@ func (c *Conference) OnNewParticipant(participantID peer.ID, inviteEvent *event. // As per MSC3401, when the `session_id` field changes from an incoming `m.call.member` event, // any existing calls from this device in this call should be terminated. // TODO: Implement this. - /* - for _, participant := range c.participants { - if participant.data.DeviceID == inviteEvent.DeviceID { - if participant.data.RemoteSessionID == inviteEvent.SenderSessionID { - c.logger.WithFields(logrus.Fields{ - "device_id": inviteEvent.DeviceID, - "session_id": inviteEvent.SenderSessionID, - }).Errorf("Found existing participant with equal DeviceID and SessionID") - return - } else { - participant.Terminate() - delete(c.participants, participant.data.UserID) - } + for id, participant := range c.participants { + if id.DeviceID == inviteEvent.DeviceID { + if participant.remoteSessionID == inviteEvent.SenderSessionID { + c.logger.WithFields(logrus.Fields{ + "device_id": inviteEvent.DeviceID, + "session_id": inviteEvent.SenderSessionID, + }).Errorf("Found existing participant with equal DeviceID and SessionID") + return + } else { + participant.peer.Terminate() } } - */ + } - peer, _, err := peer.NewPeer(participantID, c.conferenceID, inviteEvent.Offer.SDP, c.participantsChannel) + peer, sdpOffer, err := peer.NewPeer(participantID, c.id, inviteEvent.Offer.SDP, c.peerEventsStream) if err != nil { c.logger.WithError(err).Errorf("Failed to create new peer") return } - participantData := &ParticipantData{ - RemoteSessionID: inviteEvent.SenderSessionID, - StreamMetadata: inviteEvent.SDPStreamMetadata, + participant := &Participant{ + id: participantID, + peer: peer, + remoteSessionID: inviteEvent.SenderSessionID, + streamMetadata: inviteEvent.SDPStreamMetadata, + publishedTracks: make(map[event.SFUTrackDescription]*webrtc.TrackLocalStaticRTP), } - c.participants[participantID] = &Participant{Peer: peer, Data: participantData} + c.participants[participantID] = participant - // TODO: Send the SDP answer back to the participant's device. + recipient := participant.asMatrixRecipient() + streamMetadata := c.getStreamsMetadata(participantID) + c.signaling.SendSDPAnswer(recipient, streamMetadata, sdpOffer.SDP) } func (c *Conference) OnCandidates(peerID peer.ID, candidatesEvent *event.CallCandidatesEventContent) { - if participant := c.getParticipant(peerID); participant != nil { + if participant := c.getParticipant(peerID, nil); participant != nil { // Convert the candidates to the WebRTC format. candidates := make([]webrtc.ICECandidateInit, len(candidatesEvent.Candidates)) for i, candidate := range candidatesEvent.Candidates { @@ -112,36 +101,23 @@ func (c *Conference) OnCandidates(peerID peer.ID, candidatesEvent *event.CallCan } } - participant.Peer.AddICECandidates(candidates) + participant.peer.AddICECandidates(candidates) } } func (c *Conference) OnSelectAnswer(peerID peer.ID, selectAnswerEvent *event.CallSelectAnswerEventContent) { - if participant := c.getParticipant(peerID); participant != nil { + if participant := c.getParticipant(peerID, nil); participant != nil { if selectAnswerEvent.SelectedPartyID != peerID.DeviceID.String() { c.logger.WithFields(logrus.Fields{ "device_id": selectAnswerEvent.SelectedPartyID, }).Errorf("Call was answered on a different device, kicking this peer") - participant.Peer.Terminate() + participant.peer.Terminate() } } } func (c *Conference) OnHangup(peerID peer.ID, hangupEvent *event.CallHangupEventContent) { - if participant := c.getParticipant(peerID); participant != nil { - participant.Peer.Terminate() - } -} - -func (c *Conference) getParticipant(peerID peer.ID) *Participant { - participant, ok := c.participants[peerID] - if !ok { - c.logger.WithFields(logrus.Fields{ - "user_id": peerID.UserID, - "device_id": peerID.DeviceID, - }).Errorf("Failed to find participant") - return nil + if participant := c.getParticipant(peerID, nil); participant != nil { + participant.peer.Terminate() } - - return participant } diff --git a/src/conference/config.go b/src/conference/config.go new file mode 100644 index 0000000..81952af --- /dev/null +++ b/src/conference/config.go @@ -0,0 +1,8 @@ +package conference + +// Configuration for the group conferences (calls). +type Config struct { + // Keep-alive timeout for WebRTC connections. If no keep-alive has been received + // from the client for this duration, the connection is considered dead. + KeepAliveTimeout int `yaml:"timeout"` +} diff --git a/src/conference/messages.go b/src/conference/messages.go new file mode 100644 index 0000000..e4b30ea --- /dev/null +++ b/src/conference/messages.go @@ -0,0 +1,179 @@ +package conference + +import ( + "encoding/json" + "errors" + + "github.com/matrix-org/waterfall/src/peer" + "maunium.net/go/mautrix/event" +) + +func (c *Conference) processMessages() { + for { + // Read a message from the stream (of type peer.Message) and process it. + message := <-c.peerEventsStream + c.processPeerMessage(message) + } +} + +//nolint:funlen +func (c *Conference) processPeerMessage(message peer.Message) { + // Since Go does not support ADTs, we have to use a switch statement to + // determine the actual type of the message. + switch msg := message.(type) { + case peer.JoinedTheCall: + case peer.LeftTheCall: + delete(c.participants, msg.Sender) + // TODO: Send new metadata about available streams to all participants. + // TODO: Send the hangup event over the Matrix back to the user. + + case peer.NewTrackPublished: + participant := c.getParticipant(msg.Sender, errors.New("New track published from unknown participant")) + if participant == nil { + return + } + + key := event.SFUTrackDescription{ + StreamID: msg.Track.StreamID(), + TrackID: msg.Track.ID(), + } + + if _, ok := participant.publishedTracks[key]; ok { + c.logger.Errorf("Track already published: %v", key) + return + } + + participant.publishedTracks[key] = msg.Track + + case peer.PublishedTrackFailed: + participant := c.getParticipant(msg.Sender, errors.New("Published track failed from unknown participant")) + if participant == nil { + return + } + + delete(participant.publishedTracks, event.SFUTrackDescription{ + StreamID: msg.Track.StreamID(), + TrackID: msg.Track.ID(), + }) + + // TODO: Should we remove the local tracks from every subscriber as well? Or will it happen automatically? + + case peer.NewICECandidate: + participant := c.getParticipant(msg.Sender, errors.New("ICE candidate from unknown participant")) + if participant == nil { + return + } + + // Convert WebRTC ICE candidate to Matrix ICE candidate. + jsonCandidate := msg.Candidate.ToJSON() + candidates := []event.CallCandidate{{ + Candidate: jsonCandidate.Candidate, + SDPMLineIndex: int(*jsonCandidate.SDPMLineIndex), + SDPMID: *jsonCandidate.SDPMid, + }} + c.signaling.SendICECandidates(participant.asMatrixRecipient(), candidates) + + case peer.ICEGatheringComplete: + participant := c.getParticipant(msg.Sender, errors.New("Received ICE complete from unknown participant")) + if participant == nil { + return + } + + // Send an empty array of candidates to indicate that ICE gathering is complete. + c.signaling.SendCandidatesGatheringFinished(participant.asMatrixRecipient()) + + case peer.RenegotiationRequired: + participant := c.getParticipant(msg.Sender, errors.New("Renegotiation from unknown participant")) + if participant == nil { + return + } + + toSend := event.SFUMessage{ + Op: event.SFUOperationOffer, + SDP: msg.Offer.SDP, + Metadata: c.getStreamsMetadata(participant.id), + } + + participant.sendDataChannelMessage(toSend) + + case peer.DataChannelMessage: + participant := c.getParticipant(msg.Sender, errors.New("Data channel message from unknown participant")) + if participant == nil { + return + } + + var sfuMessage event.SFUMessage + if err := json.Unmarshal([]byte(msg.Message), &sfuMessage); err != nil { + c.logger.Errorf("Failed to unmarshal SFU message: %v", err) + return + } + + c.handleDataChannelMessage(participant, sfuMessage) + + case peer.DataChannelAvailable: + participant := c.getParticipant(msg.Sender, errors.New("Data channel available from unknown participant")) + if participant == nil { + return + } + + toSend := event.SFUMessage{ + Op: event.SFUOperationMetadata, + Metadata: c.getStreamsMetadata(participant.id), + } + + if err := participant.sendDataChannelMessage(toSend); err != nil { + c.logger.Errorf("Failed to send SFU message to open data channel: %v", err) + return + } + + default: + c.logger.Errorf("Unknown message type: %T", msg) + } +} + +// Handle the `SFUMessage` event from the DataChannel message. +func (c *Conference) handleDataChannelMessage(participant *Participant, sfuMessage event.SFUMessage) { + switch sfuMessage.Op { + case event.SFUOperationSelect: + // Get the tracks that correspond to the tracks that the participant wants to receive. + for _, track := range c.getTracks(sfuMessage.Start) { + if err := participant.peer.SubscribeToTrack(track); err != nil { + c.logger.Errorf("Failed to subscribe to track: %v", err) + return + } + } + + case event.SFUOperationAnswer: + if err := participant.peer.NewSDPAnswerReceived(sfuMessage.SDP); err != nil { + c.logger.Errorf("Failed to set SDP answer: %v", err) + return + } + + // TODO: Clarify the semantics of publish (just a new sdp offer?). + case event.SFUOperationPublish: + // TODO: Clarify the semantics of publish (how is it different from unpublish?). + case event.SFUOperationUnpublish: + // TODO: Handle the heartbeat message here (updating the last timestamp etc). + case event.SFUOperationAlive: + case event.SFUOperationMetadata: + participant.streamMetadata = sfuMessage.Metadata + + // Inform all participants about new metadata available. + for id, participant := range c.participants { + // Skip ourselves. + if id == participant.id { + continue + } + + toSend := event.SFUMessage{ + Op: event.SFUOperationMetadata, + Metadata: c.getStreamsMetadata(id), + } + + if err := participant.sendDataChannelMessage(toSend); err != nil { + c.logger.Errorf("Failed to send SFU message: %v", err) + return + } + } + } +} diff --git a/src/conference/participant.go b/src/conference/participant.go new file mode 100644 index 0000000..fc0ced3 --- /dev/null +++ b/src/conference/participant.go @@ -0,0 +1,91 @@ +package conference + +import ( + "encoding/json" + "errors" + + "github.com/matrix-org/waterfall/src/peer" + "github.com/matrix-org/waterfall/src/signaling" + "github.com/pion/webrtc/v3" + "github.com/sirupsen/logrus" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +var ErrInvalidSFUMessage = errors.New("invalid SFU message") + +type Participant struct { + id peer.ID + peer *peer.Peer + remoteSessionID id.SessionID + streamMetadata event.CallSDPStreamMetadata + publishedTracks map[event.SFUTrackDescription]*webrtc.TrackLocalStaticRTP +} + +func (p *Participant) asMatrixRecipient() signaling.MatrixRecipient { + return signaling.MatrixRecipient{ + ID: p.id, + RemoteSessionID: p.remoteSessionID, + } +} + +func (p *Participant) sendDataChannelMessage(toSend event.SFUMessage) error { + jsonToSend, err := json.Marshal(toSend) + if err != nil { + return ErrInvalidSFUMessage + } + + if err := p.peer.SendOverDataChannel(string(jsonToSend)); err != nil { + // FIXME: We must buffer the message in this case and re-send it once the data channel is recovered! + // Or use Matrix signaling to inform the peer about the problem. + return err + } + + return nil +} + +func (c *Conference) getParticipant(peerID peer.ID, optionalErrorMessage error) *Participant { + participant, ok := c.participants[peerID] + if !ok { + logEntry := c.logger.WithFields(logrus.Fields{ + "user_id": peerID.UserID, + "device_id": peerID.DeviceID, + }) + + if optionalErrorMessage != nil { + logEntry.WithError(optionalErrorMessage) + } else { + logEntry.Error("Participant not found") + } + + return nil + } + + return participant +} + +func (c *Conference) getStreamsMetadata(forParticipant peer.ID) event.CallSDPStreamMetadata { + streamsMetadata := make(event.CallSDPStreamMetadata) + for id, participant := range c.participants { + if forParticipant != id { + for streamID, metadata := range participant.streamMetadata { + streamsMetadata[streamID] = metadata + } + } + } + + return streamsMetadata +} + +func (c *Conference) getTracks(identifiers []event.SFUTrackDescription) []*webrtc.TrackLocalStaticRTP { + tracks := make([]*webrtc.TrackLocalStaticRTP, len(identifiers)) + for _, participant := range c.participants { + // Check if this participant has any of the tracks that we're looking for. + for _, identifier := range identifiers { + if track, ok := participant.publishedTracks[identifier]; ok { + tracks = append(tracks, track) + } + } + } + return tracks +} diff --git a/src/config/config.go b/src/config/config.go index 691ace9..d9adc83 100644 --- a/src/config/config.go +++ b/src/config/config.go @@ -5,22 +5,18 @@ import ( "fmt" "os" + "github.com/matrix-org/waterfall/src/conference" + "github.com/matrix-org/waterfall/src/signaling" "github.com/sirupsen/logrus" "gopkg.in/yaml.v3" - "maunium.net/go/mautrix/id" ) -// The mandatory SFU configuration. +// SFU configuration. type Config struct { - // The Matrix ID (MXID) of the SFU. - UserID id.UserID - // The ULR of the homeserver that SFU talks to. - HomeserverURL string - // The access token for the Matrix SDK. - AccessToken string - // Keep-alive timeout for WebRTC connections. If no keep-alive has been received - // from the client for this duration, the connection is considered dead. - KeepAliveTimeout int + // Matrix configuration. + Matrix signaling.Config `yaml:"matrix"` + // Conference (call) configuration. + Conference conference.Config `yaml:"conference"` } // Tries to load a config from the `CONFIG` environment variable. @@ -76,5 +72,12 @@ func LoadConfigFromString(configString string) (*Config, error) { return nil, fmt.Errorf("failed to unmarshal YAML file: %w", err) } + if config.Matrix.UserID == "" || + config.Matrix.HomeserverURL == "" || + config.Matrix.AccessToken == "" || + config.Conference.KeepAliveTimeout == 0 { + return nil, errors.New("invalid config values") + } + return &config, nil } diff --git a/src/main.go b/src/main.go index 320e6a0..cfffc62 100644 --- a/src/main.go +++ b/src/main.go @@ -65,7 +65,15 @@ func main() { config, err := config.LoadConfig(*configFilePath) if err != nil { logrus.WithError(err).Fatal("could not load config") + return } - signaling.RunServer(config) + // Create matrix client. + matrixClient := signaling.NewMatrixClient(config.Matrix) + + // Create a router to route incoming To-Device messages to the right conference. + router := newRouter(matrixClient, config.Conference) + + // Start matrix client sync. This function will block until the sync fails. + matrixClient.RunSync(router.handleMatrixEvent) } diff --git a/src/peer/channel.go b/src/peer/channel.go index 0f5e127..b0c573c 100644 --- a/src/peer/channel.go +++ b/src/peer/channel.go @@ -4,13 +4,13 @@ import ( "github.com/pion/webrtc/v3" ) -type MessageChannel chan interface{} +type Message = interface{} -type PeerJoinedTheCall struct { +type JoinedTheCall struct { Sender ID } -type PeerLeftTheCall struct { +type LeftTheCall struct { Sender ID } @@ -33,25 +33,16 @@ type ICEGatheringComplete struct { Sender ID } -type NewOffer struct { +type RenegotiationRequired struct { Sender ID Offer *webrtc.SessionDescription } -type DataChannelOpened struct { - Sender ID -} - -type DataChannelClosed struct { - Sender ID -} - type DataChannelMessage struct { Sender ID Message string } -type DataChannelError struct { +type DataChannelAvailable struct { Sender ID - Err error } diff --git a/src/peer/peer.go b/src/peer/peer.go index e9a3b8b..84506d0 100644 --- a/src/peer/peer.go +++ b/src/peer/peer.go @@ -15,6 +15,7 @@ var ( ErrCantSetLocalDescription = errors.New("can't set local description") ErrCantCreateLocalDescription = errors.New("can't create local description") ErrDataChannelNotAvailable = errors.New("data channel is not available") + ErrDataChannelNotReady = errors.New("data channel is not ready") ErrCantSubscribeToTrack = errors.New("can't subscribe to track") ) @@ -103,7 +104,7 @@ func (p *Peer) Terminate() { p.logger.WithError(err).Error("failed to close peer connection") } - p.notify <- PeerLeftTheCall{Sender: p.id} + p.notify <- LeftTheCall{Sender: p.id} } func (p *Peer) AddICECandidates(candidates []webrtc.ICECandidateInit) { @@ -133,9 +134,27 @@ func (p *Peer) SendOverDataChannel(json string) error { return ErrDataChannelNotAvailable } + if p.dataChannel.ReadyState() != webrtc.DataChannelStateOpen { + p.logger.Error("can't send data over data channel: data channel is not open") + return ErrDataChannelNotReady + } + if err := p.dataChannel.SendText(json); err != nil { p.logger.WithError(err).Error("failed to send data over data channel") } return nil } + +func (p *Peer) NewSDPAnswerReceived(sdpAnswer string) error { + err := p.peerConnection.SetRemoteDescription(webrtc.SessionDescription{ + Type: webrtc.SDPTypeAnswer, + SDP: sdpAnswer, + }) + if err != nil { + p.logger.WithError(err).Error("failed to set remote description") + return ErrCantSetRemoteDecsription + } + + return nil +} diff --git a/src/peer/webrtc.go b/src/peer/webrtc.go index 4d97a80..889416c 100644 --- a/src/peer/webrtc.go +++ b/src/peer/webrtc.go @@ -93,7 +93,7 @@ func (p *Peer) onNegotiationNeeded() { return } - p.notify <- NewOffer{Sender: p.id, Offer: &offer} + p.notify <- RenegotiationRequired{Sender: p.id, Offer: &offer} } // A callback that is called once we receive an ICE connection state change for this peer connection. @@ -132,9 +132,9 @@ func (p *Peer) onConnectionStateChanged(state webrtc.PeerConnectionState) { switch state { case webrtc.PeerConnectionStateFailed, webrtc.PeerConnectionStateDisconnected, webrtc.PeerConnectionStateClosed: - p.notify <- PeerLeftTheCall{Sender: p.id} + p.notify <- LeftTheCall{Sender: p.id} case webrtc.PeerConnectionStateConnected: - p.notify <- PeerJoinedTheCall{Sender: p.id} + p.notify <- JoinedTheCall{Sender: p.id} } } @@ -154,12 +154,7 @@ func (p *Peer) onDataChannelReady(dc *webrtc.DataChannel) { dc.OnOpen(func() { p.logger.Info("data channel opened") - p.notify <- DataChannelOpened{Sender: p.id} - }) - - dc.OnClose(func() { - p.logger.Info("data channel closed") - p.notify <- DataChannelClosed{Sender: p.id} + p.notify <- DataChannelAvailable{Sender: p.id} }) dc.OnMessage(func(msg webrtc.DataChannelMessage) { @@ -173,6 +168,9 @@ func (p *Peer) onDataChannelReady(dc *webrtc.DataChannel) { dc.OnError(func(err error) { p.logger.WithError(err).Error("data channel error") - p.notify <- DataChannelError{Sender: p.id, Err: err} + }) + + dc.OnClose(func() { + p.logger.Info("data channel closed") }) } diff --git a/src/signaling/signaling.go b/src/router.go similarity index 54% rename from src/signaling/signaling.go rename to src/router.go index ab9ea6e..e7b3899 100644 --- a/src/signaling/signaling.go +++ b/src/router.go @@ -14,53 +14,39 @@ See the License for the specific language governing permissions and limitations under the License. */ -package signaling +package main import ( - "errors" - "github.com/matrix-org/waterfall/src/conference" "github.com/matrix-org/waterfall/src/peer" + "github.com/matrix-org/waterfall/src/signaling" "github.com/sirupsen/logrus" - "maunium.net/go/mautrix" "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" ) -var ErrNoSuchConference = errors.New("no such conference") - -// The top-level state of the SignalingServer. -// Note that in Matrix MSCs, the term "focus" is used to refer to the SignalingServer. But since "focus" is a very -// generic name and only makes sense in a certain context, we use the term "SignalingServer" instead to avoid confusion -// given that this particular part is just the SignalingServer logic (and not the "focus" selection algorithm etc). -type SignalingServer struct { - // Matrix client. - client *mautrix.Client +// The top-level state of the Router. +type Router struct { + // Matrix matrix. + matrix *signaling.MatrixClient // All calls currently forwarded by this SFU. conferences map[string]*conference.Conference // Configuration for the calls. - config *conference.CallConfig + config conference.Config } // Creates a new instance of the SFU with the given configuration. -func NewSignalingServer(client *mautrix.Client, config *conference.CallConfig) *SignalingServer { - return &SignalingServer{ - client: client, +func newRouter(matrix *signaling.MatrixClient, config conference.Config) *Router { + return &Router{ + matrix: matrix, conferences: make(map[string]*conference.Conference), config: config, } } -// Handles To-Device events that the SFU receives from clients. +// Handles incoming To-Device events that the SFU receives from clients. // //nolint:funlen -func (f *SignalingServer) onMatrixEvent(_ mautrix.EventSource, evt *event.Event) { - // We only care about to-device events. - if evt.Type.Class != event.ToDeviceEventType { - logrus.Warn("ignoring a not to-device event") - return - } - +func (r *Router) handleMatrixEvent(evt *event.Event) { // TODO: Don't create logger again and again, it might be a bit expensive. logger := logrus.WithFields(logrus.Fields{ "type": evt.Type.Type, @@ -68,11 +54,6 @@ func (f *SignalingServer) onMatrixEvent(_ mautrix.EventSource, evt *event.Event) "conf_id": evt.Content.Raw["conf_id"], }) - if evt.Content.Raw["dest_session_id"] != LocalSessionID { - logger.WithField("dest_session_id", LocalSessionID).Warn("SessionID does not match our SessionID - ignoring") - return - } - switch evt.Type.Type { // Someone tries to participate in a call (join a call). case event.ToDeviceCallInvite.Type: @@ -83,9 +64,13 @@ func (f *SignalingServer) onMatrixEvent(_ mautrix.EventSource, evt *event.Event) } // If there is an invitation sent and the conf does not exist, create one. - if conf := f.conferences[invite.ConfID]; conf == nil { + if conf := r.conferences[invite.ConfID]; conf == nil { logger.Infof("creating new conference %s", invite.ConfID) - f.conferences[invite.ConfID] = conference.NewConference(invite.ConfID, f.config) + r.conferences[invite.ConfID] = conference.NewConference( + invite.ConfID, + r.config, + r.matrix.CreateForConference(invite.ConfID), + ) } peerID := peer.ID{ @@ -94,7 +79,7 @@ func (f *SignalingServer) onMatrixEvent(_ mautrix.EventSource, evt *event.Event) } // Inform conference about incoming participant. - f.conferences[invite.ConfID].OnNewParticipant(peerID, invite) + r.conferences[invite.ConfID].OnNewParticipant(peerID, invite) // Someone tries to send ICE candidates to the existing call. case event.ToDeviceCallCandidates.Type: @@ -104,7 +89,7 @@ func (f *SignalingServer) onMatrixEvent(_ mautrix.EventSource, evt *event.Event) return } - conference := f.conferences[candidates.ConfID] + conference := r.conferences[candidates.ConfID] if conference == nil { logger.Errorf("received candidates for unknown conference %s", candidates.ConfID) return @@ -125,7 +110,7 @@ func (f *SignalingServer) onMatrixEvent(_ mautrix.EventSource, evt *event.Event) return } - conference := f.conferences[selectAnswer.ConfID] + conference := r.conferences[selectAnswer.ConfID] if conference == nil { logger.Errorf("received select_answer for unknown conference %s", selectAnswer.ConfID) return @@ -146,7 +131,7 @@ func (f *SignalingServer) onMatrixEvent(_ mautrix.EventSource, evt *event.Event) return } - conference := f.conferences[hangup.ConfID] + conference := r.conferences[hangup.ConfID] if conference == nil { logger.Errorf("received hangup for unknown conference %s", hangup.ConfID) return @@ -170,60 +155,3 @@ func (f *SignalingServer) onMatrixEvent(_ mautrix.EventSource, evt *event.Event) logger.Warnf("ignoring unexpected event: %s", evt.Type.Type) } } - -func (f *SignalingServer) createSDPAnswerEvent( - conferenceID string, - destSessionID id.SessionID, - peerID peer.ID, - sdp string, - streamMetadata event.CallSDPStreamMetadata, -) *event.Content { - return &event.Content{ - Parsed: event.CallAnswerEventContent{ - BaseCallEventContent: createBaseEventContent(conferenceID, f.client.DeviceID, peerID.DeviceID, destSessionID), - Answer: event.CallData{ - Type: "answer", - SDP: sdp, - }, - SDPStreamMetadata: streamMetadata, - }, - } -} - -func createBaseEventContent( - conferenceID string, - sfuDeviceID id.DeviceID, - destDeviceID id.DeviceID, - destSessionID id.SessionID, -) event.BaseCallEventContent { - return event.BaseCallEventContent{ - CallID: conferenceID, - ConfID: conferenceID, - DeviceID: sfuDeviceID, - SenderSessionID: LocalSessionID, - DestSessionID: destSessionID, - PartyID: string(destDeviceID), - Version: event.CallVersion("1"), - } -} - -// Sends a to-device event to the given user. -func (f *SignalingServer) sendToDevice(participantID peer.ID, ev *event.Event) { - // TODO: Don't create logger again and again, it might be a bit expensive. - logger := logrus.WithFields(logrus.Fields{ - "user_id": participantID.UserID, - "device_id": participantID.DeviceID, - }) - - sendRequest := &mautrix.ReqSendToDevice{ - Messages: map[id.UserID]map[id.DeviceID]*event.Content{ - participantID.UserID: { - participantID.DeviceID: &ev.Content, - }, - }, - } - - if _, err := f.client.SendToDevice(ev.Type, sendRequest); err != nil { - logger.Errorf("failed to send to-device event: %w", err) - } -} diff --git a/src/signaling/config.go b/src/signaling/config.go new file mode 100644 index 0000000..acd730c --- /dev/null +++ b/src/signaling/config.go @@ -0,0 +1,13 @@ +package signaling + +import "maunium.net/go/mautrix/id" + +// Configuration for the Matrix client. +type Config struct { + // The Matrix ID (MXID) of the SFU. + UserID id.UserID `yaml:"userid"` + // The ULR of the homeserver that SFU talks to. + HomeserverURL string `yaml:"homeserverurl"` + // The access token for the Matrix SDK. + AccessToken string `yaml:"accesstoken"` +} diff --git a/src/signaling/matrix.go b/src/signaling/matrix.go index d2d43a2..3c56a27 100644 --- a/src/signaling/matrix.go +++ b/src/signaling/matrix.go @@ -17,17 +17,20 @@ limitations under the License. package signaling import ( - "github.com/matrix-org/waterfall/src/conference" - "github.com/matrix-org/waterfall/src/config" + "github.com/matrix-org/waterfall/src/peer" "github.com/sirupsen/logrus" "maunium.net/go/mautrix" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" ) const LocalSessionID = "sfu" -// Starts the Matrix client and connects to the homeserver, -// runs the SFU. Returns only when the sync with Matrix fails. -func RunServer(config *config.Config) { +type MatrixClient struct { + client *mautrix.Client +} + +func NewMatrixClient(config Config) *MatrixClient { client, err := mautrix.NewClient(config.HomeserverURL, config.UserID, config.AccessToken) if err != nil { logrus.WithError(err).Fatal("Failed to create client") @@ -45,23 +48,152 @@ func RunServer(config *config.Config) { logrus.WithField("device_id", whoami.DeviceID).Info("Identified SFU as DeviceID") client.DeviceID = whoami.DeviceID - focus := NewSignalingServer( - client, - &conference.CallConfig{KeepAliveTimeout: config.KeepAliveTimeout}, - ) + return &MatrixClient{ + client: client, + } +} - syncer, ok := client.Syncer.(*mautrix.DefaultSyncer) +// Starts the Matrix client and connects to the homeserver, +// Returns only when the sync with Matrix fails. +func (m *MatrixClient) RunSync(callback func(*event.Event)) { + syncer, ok := m.client.Syncer.(*mautrix.DefaultSyncer) if !ok { logrus.Panic("Syncer is not DefaultSyncer") } syncer.ParseEventContent = true - syncer.OnEvent(focus.onMatrixEvent) + syncer.OnEvent(func(_ mautrix.EventSource, evt *event.Event) { + // We only care about to-device events. + if evt.Type.Class != event.ToDeviceEventType { + logrus.Warn("ignoring a not to-device event") + return + } + + // We drop the messages if they are not meant for us. + if evt.Content.Raw["dest_session_id"] != LocalSessionID { + logrus.Warn("SessionID does not match our SessionID - ignoring") + return + } + + callback(evt) + }) // TODO: We may want to reconnect if `Sync()` fails instead of ending the SFU // as ending here will essentially drop all conferences which may not necessarily // be what we want for the existing running conferences. - if err = client.Sync(); err != nil { + if err := m.client.Sync(); err != nil { logrus.WithError(err).Panic("Sync failed") } } + +func (m *MatrixClient) CreateForConference(conferenceID string) *MatrixForConference { + return &MatrixForConference{ + client: m.client, + conferenceID: conferenceID, + } +} + +type MatrixRecipient struct { + ID peer.ID + RemoteSessionID id.SessionID +} + +type MatrixSignaling interface { + SendSDPAnswer(recipient MatrixRecipient, streamMetadata event.CallSDPStreamMetadata, sdp string) + SendICECandidates(recipient MatrixRecipient, candidates []event.CallCandidate) + SendCandidatesGatheringFinished(recipient MatrixRecipient) + SendHangup(recipient MatrixRecipient, reason event.CallHangupReason) +} + +type MatrixForConference struct { + client *mautrix.Client + conferenceID string +} + +func (m *MatrixForConference) SendSDPAnswer( + recipient MatrixRecipient, + streamMetadata event.CallSDPStreamMetadata, + sdp string, +) { + eventContent := &event.Content{ + Parsed: event.CallAnswerEventContent{ + BaseCallEventContent: m.createBaseEventContent(recipient.ID.DeviceID, recipient.RemoteSessionID), + Answer: event.CallData{ + Type: "answer", + SDP: sdp, + }, + SDPStreamMetadata: streamMetadata, + }, + } + + m.sendToDevice(recipient.ID, event.CallAnswer, eventContent) +} + +func (m *MatrixForConference) SendICECandidates(recipient MatrixRecipient, candidates []event.CallCandidate) { + eventContent := &event.Content{ + Parsed: event.CallCandidatesEventContent{ + BaseCallEventContent: m.createBaseEventContent(recipient.ID.DeviceID, recipient.RemoteSessionID), + Candidates: candidates, + }, + } + + m.sendToDevice(recipient.ID, event.CallCandidates, eventContent) +} + +func (m *MatrixForConference) SendCandidatesGatheringFinished(recipient MatrixRecipient) { + eventContent := &event.Content{ + Parsed: event.CallCandidatesEventContent{ + BaseCallEventContent: m.createBaseEventContent(recipient.ID.DeviceID, recipient.RemoteSessionID), + Candidates: []event.CallCandidate{{Candidate: ""}}, + }, + } + + m.sendToDevice(recipient.ID, event.CallCandidates, eventContent) +} + +func (m *MatrixForConference) SendHangup(recipient MatrixRecipient, reason event.CallHangupReason) { + eventContent := &event.Content{ + Parsed: event.CallHangupEventContent{ + BaseCallEventContent: m.createBaseEventContent(recipient.ID.DeviceID, recipient.RemoteSessionID), + Reason: reason, + }, + } + + m.sendToDevice(recipient.ID, event.CallHangup, eventContent) +} + +func (m *MatrixForConference) createBaseEventContent( + destDeviceID id.DeviceID, + destSessionID id.SessionID, +) event.BaseCallEventContent { + return event.BaseCallEventContent{ + CallID: m.conferenceID, + ConfID: m.conferenceID, + DeviceID: m.client.DeviceID, + SenderSessionID: LocalSessionID, + DestSessionID: destSessionID, + PartyID: string(destDeviceID), + Version: event.CallVersion("1"), + } +} + +// Sends a to-device event to the given user. +func (m *MatrixForConference) sendToDevice(participantID peer.ID, eventType event.Type, eventContent *event.Content) { + // TODO: Don't create logger again and again, it might be a bit expensive. + logger := logrus.WithFields(logrus.Fields{ + "user_id": participantID.UserID, + "device_id": participantID.DeviceID, + }) + + sendRequest := &mautrix.ReqSendToDevice{ + Messages: map[id.UserID]map[id.DeviceID]*event.Content{ + participantID.UserID: { + participantID.DeviceID: eventContent, + }, + }, + } + + if _, err := m.client.SendToDevice(eventType, sendRequest); err != nil { + logger.Errorf("failed to send to-device event: %w", err) + } +} From abef0c55bcb40c3fb67ed5940d5765c3fb149715 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Thu, 24 Nov 2022 17:37:32 +0100 Subject: [PATCH 06/62] refactor: define a package for message sink This allows to generalize the message sink and get rid of a lot of copy-paste in the handling functions. Also this moves types to the right modules, so that `peer` is now completely matrix-unaware module that contains only plain WebRTC logic. --- src/common/message_sink.go | 38 +++++++ src/conference/conference.go | 58 ++++++----- src/conference/participant.go | 22 ++-- src/conference/{messages.go => processor.go} | 50 ++-------- src/peer/channel.go | 48 --------- src/peer/id.go | 8 -- src/peer/messages.go | 35 +++++++ src/peer/peer.go | 38 +++---- src/peer/webrtc.go | 38 +++---- src/router.go | 25 +++-- src/signaling/client.go | 67 +++++++++++++ src/signaling/matrix.go | 100 ++++--------------- 12 files changed, 266 insertions(+), 261 deletions(-) create mode 100644 src/common/message_sink.go rename src/conference/{messages.go => processor.go} (77%) delete mode 100644 src/peer/channel.go delete mode 100644 src/peer/id.go create mode 100644 src/peer/messages.go create mode 100644 src/signaling/client.go diff --git a/src/common/message_sink.go b/src/common/message_sink.go new file mode 100644 index 0000000..3351492 --- /dev/null +++ b/src/common/message_sink.go @@ -0,0 +1,38 @@ +package common + +// MessageSink is a helper struct that allows to send messages to a message sink. +// The MessageSink abstracts the message sink which has a certain sender, so that +// the sender does not have to be specified every time a message is sent. +// At the same it guarantees that the caller can't alter the `sender`, which means that +// the sender can't impersonate another sender (and we guarantee this on a compile-time). +type MessageSink[SenderType comparable, MessageType any] struct { + // The sender of the messages. This is useful for multiple-producer-single-consumer scenarios. + sender SenderType + // The message sink to which the messages are sent. + messageSink chan<- Message[SenderType, MessageType] +} + +// Creates a new MessageSink. The function is generic allowing us to use it for multiple use cases. +func NewMessageSink[S comparable, M any](sender S, messageSink chan<- Message[S, M]) *MessageSink[S, M] { + return &MessageSink[S, M]{ + sender: sender, + messageSink: messageSink, + } +} + +// Sends a message to the message sink. +func (s *MessageSink[S, M]) Send(message M) { + s.messageSink <- Message[S, M]{ + Sender: s.sender, + Content: message, + } +} + +// Messages that are sent from the peer to the conference in order to communicate with other peers. +// Since each peer is isolated from others, it can't influence the state of other peers directly. +type Message[SenderType comparable, MessageType any] struct { + // The sender of the message. + Sender SenderType + // The content of the message. + Content MessageType +} diff --git a/src/conference/conference.go b/src/conference/conference.go index 1a2a243..56ef684 100644 --- a/src/conference/conference.go +++ b/src/conference/conference.go @@ -17,6 +17,7 @@ limitations under the License. package conference import ( + "github.com/matrix-org/waterfall/src/common" "github.com/matrix-org/waterfall/src/peer" "github.com/matrix-org/waterfall/src/signaling" "github.com/pion/webrtc/v3" @@ -25,22 +26,22 @@ import ( ) type Conference struct { - id string - config Config - signaling signaling.MatrixSignaling - participants map[peer.ID]*Participant - peerEventsStream chan peer.Message - logger *logrus.Entry + id string + config Config + signaling signaling.MatrixSignaling + participants map[ParticipantID]*Participant + peerEvents chan common.Message[ParticipantID, peer.MessageContent] + logger *logrus.Entry } func NewConference(confID string, config Config, signaling signaling.MatrixSignaling) *Conference { conference := &Conference{ - id: confID, - config: config, - signaling: signaling, - participants: make(map[peer.ID]*Participant), - peerEventsStream: make(chan peer.Message), - logger: logrus.WithFields(logrus.Fields{"conf_id": confID}), + id: confID, + config: config, + signaling: signaling, + participants: make(map[ParticipantID]*Participant), + peerEvents: make(chan common.Message[ParticipantID, peer.MessageContent]), + logger: logrus.WithFields(logrus.Fields{"conf_id": confID}), } // Start conference "main loop". @@ -49,7 +50,7 @@ func NewConference(confID string, config Config, signaling signaling.MatrixSigna } // New participant tries to join the conference. -func (c *Conference) OnNewParticipant(participantID peer.ID, inviteEvent *event.CallInviteEventContent) { +func (c *Conference) OnNewParticipant(participantID ParticipantID, inviteEvent *event.CallInviteEventContent) { // As per MSC3401, when the `session_id` field changes from an incoming `m.call.member` event, // any existing calls from this device in this call should be terminated. // TODO: Implement this. @@ -67,7 +68,16 @@ func (c *Conference) OnNewParticipant(participantID peer.ID, inviteEvent *event. } } - peer, sdpOffer, err := peer.NewPeer(participantID, c.id, inviteEvent.Offer.SDP, c.peerEventsStream) + var ( + participantlogger = logrus.WithFields(logrus.Fields{ + "user_id": participantID.UserID, + "device_id": participantID.DeviceID, + "conf_id": c.id, + }) + messageSink = common.NewMessageSink(participantID, c.peerEvents) + ) + + peer, sdpOffer, err := peer.NewPeer(inviteEvent.Offer.SDP, messageSink, participantlogger) if err != nil { c.logger.WithError(err).Errorf("Failed to create new peer") return @@ -88,11 +98,11 @@ func (c *Conference) OnNewParticipant(participantID peer.ID, inviteEvent *event. c.signaling.SendSDPAnswer(recipient, streamMetadata, sdpOffer.SDP) } -func (c *Conference) OnCandidates(peerID peer.ID, candidatesEvent *event.CallCandidatesEventContent) { - if participant := c.getParticipant(peerID, nil); participant != nil { +func (c *Conference) OnCandidates(participantID ParticipantID, ev *event.CallCandidatesEventContent) { + if participant := c.getParticipant(participantID, nil); participant != nil { // Convert the candidates to the WebRTC format. - candidates := make([]webrtc.ICECandidateInit, len(candidatesEvent.Candidates)) - for i, candidate := range candidatesEvent.Candidates { + candidates := make([]webrtc.ICECandidateInit, len(ev.Candidates)) + for i, candidate := range ev.Candidates { SDPMLineIndex := uint16(candidate.SDPMLineIndex) candidates[i] = webrtc.ICECandidateInit{ Candidate: candidate.Candidate, @@ -105,19 +115,19 @@ func (c *Conference) OnCandidates(peerID peer.ID, candidatesEvent *event.CallCan } } -func (c *Conference) OnSelectAnswer(peerID peer.ID, selectAnswerEvent *event.CallSelectAnswerEventContent) { - if participant := c.getParticipant(peerID, nil); participant != nil { - if selectAnswerEvent.SelectedPartyID != peerID.DeviceID.String() { +func (c *Conference) OnSelectAnswer(participantID ParticipantID, ev *event.CallSelectAnswerEventContent) { + if participant := c.getParticipant(participantID, nil); participant != nil { + if ev.SelectedPartyID != participantID.DeviceID.String() { c.logger.WithFields(logrus.Fields{ - "device_id": selectAnswerEvent.SelectedPartyID, + "device_id": ev.SelectedPartyID, }).Errorf("Call was answered on a different device, kicking this peer") participant.peer.Terminate() } } } -func (c *Conference) OnHangup(peerID peer.ID, hangupEvent *event.CallHangupEventContent) { - if participant := c.getParticipant(peerID, nil); participant != nil { +func (c *Conference) OnHangup(participantID ParticipantID, ev *event.CallHangupEventContent) { + if participant := c.getParticipant(participantID, nil); participant != nil { participant.peer.Terminate() } } diff --git a/src/conference/participant.go b/src/conference/participant.go index fc0ced3..08b5f88 100644 --- a/src/conference/participant.go +++ b/src/conference/participant.go @@ -14,9 +14,14 @@ import ( var ErrInvalidSFUMessage = errors.New("invalid SFU message") +type ParticipantID struct { + UserID id.UserID + DeviceID id.DeviceID +} + type Participant struct { - id peer.ID - peer *peer.Peer + id ParticipantID + peer *peer.Peer[ParticipantID] remoteSessionID id.SessionID streamMetadata event.CallSDPStreamMetadata publishedTracks map[event.SFUTrackDescription]*webrtc.TrackLocalStaticRTP @@ -24,7 +29,8 @@ type Participant struct { func (p *Participant) asMatrixRecipient() signaling.MatrixRecipient { return signaling.MatrixRecipient{ - ID: p.id, + UserID: p.id.UserID, + DeviceID: p.id.DeviceID, RemoteSessionID: p.remoteSessionID, } } @@ -44,12 +50,12 @@ func (p *Participant) sendDataChannelMessage(toSend event.SFUMessage) error { return nil } -func (c *Conference) getParticipant(peerID peer.ID, optionalErrorMessage error) *Participant { - participant, ok := c.participants[peerID] +func (c *Conference) getParticipant(participantID ParticipantID, optionalErrorMessage error) *Participant { + participant, ok := c.participants[participantID] if !ok { logEntry := c.logger.WithFields(logrus.Fields{ - "user_id": peerID.UserID, - "device_id": peerID.DeviceID, + "user_id": participantID.UserID, + "device_id": participantID.DeviceID, }) if optionalErrorMessage != nil { @@ -64,7 +70,7 @@ func (c *Conference) getParticipant(peerID peer.ID, optionalErrorMessage error) return participant } -func (c *Conference) getStreamsMetadata(forParticipant peer.ID) event.CallSDPStreamMetadata { +func (c *Conference) getStreamsMetadata(forParticipant ParticipantID) event.CallSDPStreamMetadata { streamsMetadata := make(event.CallSDPStreamMetadata) for id, participant := range c.participants { if forParticipant != id { diff --git a/src/conference/messages.go b/src/conference/processor.go similarity index 77% rename from src/conference/messages.go rename to src/conference/processor.go index e4b30ea..c401079 100644 --- a/src/conference/messages.go +++ b/src/conference/processor.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" + "github.com/matrix-org/waterfall/src/common" "github.com/matrix-org/waterfall/src/peer" "maunium.net/go/mautrix/event" ) @@ -11,28 +12,27 @@ import ( func (c *Conference) processMessages() { for { // Read a message from the stream (of type peer.Message) and process it. - message := <-c.peerEventsStream + message := <-c.peerEvents c.processPeerMessage(message) } } -//nolint:funlen -func (c *Conference) processPeerMessage(message peer.Message) { +func (c *Conference) processPeerMessage(message common.Message[ParticipantID, peer.MessageContent]) { + participant := c.getParticipant(message.Sender, errors.New("received a message from a deleted participant")) + if participant == nil { + return + } + // Since Go does not support ADTs, we have to use a switch statement to // determine the actual type of the message. - switch msg := message.(type) { + switch msg := message.Content.(type) { case peer.JoinedTheCall: case peer.LeftTheCall: - delete(c.participants, msg.Sender) + delete(c.participants, message.Sender) // TODO: Send new metadata about available streams to all participants. // TODO: Send the hangup event over the Matrix back to the user. case peer.NewTrackPublished: - participant := c.getParticipant(msg.Sender, errors.New("New track published from unknown participant")) - if participant == nil { - return - } - key := event.SFUTrackDescription{ StreamID: msg.Track.StreamID(), TrackID: msg.Track.ID(), @@ -46,11 +46,6 @@ func (c *Conference) processPeerMessage(message peer.Message) { participant.publishedTracks[key] = msg.Track case peer.PublishedTrackFailed: - participant := c.getParticipant(msg.Sender, errors.New("Published track failed from unknown participant")) - if participant == nil { - return - } - delete(participant.publishedTracks, event.SFUTrackDescription{ StreamID: msg.Track.StreamID(), TrackID: msg.Track.ID(), @@ -59,11 +54,6 @@ func (c *Conference) processPeerMessage(message peer.Message) { // TODO: Should we remove the local tracks from every subscriber as well? Or will it happen automatically? case peer.NewICECandidate: - participant := c.getParticipant(msg.Sender, errors.New("ICE candidate from unknown participant")) - if participant == nil { - return - } - // Convert WebRTC ICE candidate to Matrix ICE candidate. jsonCandidate := msg.Candidate.ToJSON() candidates := []event.CallCandidate{{ @@ -74,20 +64,10 @@ func (c *Conference) processPeerMessage(message peer.Message) { c.signaling.SendICECandidates(participant.asMatrixRecipient(), candidates) case peer.ICEGatheringComplete: - participant := c.getParticipant(msg.Sender, errors.New("Received ICE complete from unknown participant")) - if participant == nil { - return - } - // Send an empty array of candidates to indicate that ICE gathering is complete. c.signaling.SendCandidatesGatheringFinished(participant.asMatrixRecipient()) case peer.RenegotiationRequired: - participant := c.getParticipant(msg.Sender, errors.New("Renegotiation from unknown participant")) - if participant == nil { - return - } - toSend := event.SFUMessage{ Op: event.SFUOperationOffer, SDP: msg.Offer.SDP, @@ -97,11 +77,6 @@ func (c *Conference) processPeerMessage(message peer.Message) { participant.sendDataChannelMessage(toSend) case peer.DataChannelMessage: - participant := c.getParticipant(msg.Sender, errors.New("Data channel message from unknown participant")) - if participant == nil { - return - } - var sfuMessage event.SFUMessage if err := json.Unmarshal([]byte(msg.Message), &sfuMessage); err != nil { c.logger.Errorf("Failed to unmarshal SFU message: %v", err) @@ -111,11 +86,6 @@ func (c *Conference) processPeerMessage(message peer.Message) { c.handleDataChannelMessage(participant, sfuMessage) case peer.DataChannelAvailable: - participant := c.getParticipant(msg.Sender, errors.New("Data channel available from unknown participant")) - if participant == nil { - return - } - toSend := event.SFUMessage{ Op: event.SFUOperationMetadata, Metadata: c.getStreamsMetadata(participant.id), diff --git a/src/peer/channel.go b/src/peer/channel.go deleted file mode 100644 index b0c573c..0000000 --- a/src/peer/channel.go +++ /dev/null @@ -1,48 +0,0 @@ -package peer - -import ( - "github.com/pion/webrtc/v3" -) - -type Message = interface{} - -type JoinedTheCall struct { - Sender ID -} - -type LeftTheCall struct { - Sender ID -} - -type NewTrackPublished struct { - Sender ID - Track *webrtc.TrackLocalStaticRTP -} - -type PublishedTrackFailed struct { - Sender ID - Track *webrtc.TrackLocalStaticRTP -} - -type NewICECandidate struct { - Sender ID - Candidate *webrtc.ICECandidate -} - -type ICEGatheringComplete struct { - Sender ID -} - -type RenegotiationRequired struct { - Sender ID - Offer *webrtc.SessionDescription -} - -type DataChannelMessage struct { - Sender ID - Message string -} - -type DataChannelAvailable struct { - Sender ID -} diff --git a/src/peer/id.go b/src/peer/id.go deleted file mode 100644 index e7b4697..0000000 --- a/src/peer/id.go +++ /dev/null @@ -1,8 +0,0 @@ -package peer - -import "maunium.net/go/mautrix/id" - -type ID struct { - UserID id.UserID - DeviceID id.DeviceID -} diff --git a/src/peer/messages.go b/src/peer/messages.go new file mode 100644 index 0000000..51ef1d9 --- /dev/null +++ b/src/peer/messages.go @@ -0,0 +1,35 @@ +package peer + +import ( + "github.com/pion/webrtc/v3" +) + +type MessageContent = interface{} + +type JoinedTheCall struct{} + +type LeftTheCall struct{} + +type NewTrackPublished struct { + Track *webrtc.TrackLocalStaticRTP +} + +type PublishedTrackFailed struct { + Track *webrtc.TrackLocalStaticRTP +} + +type NewICECandidate struct { + Candidate *webrtc.ICECandidate +} + +type ICEGatheringComplete struct{} + +type RenegotiationRequired struct { + Offer *webrtc.SessionDescription +} + +type DataChannelMessage struct { + Message string +} + +type DataChannelAvailable struct{} diff --git a/src/peer/peer.go b/src/peer/peer.go index 84506d0..917682b 100644 --- a/src/peer/peer.go +++ b/src/peer/peer.go @@ -4,6 +4,7 @@ import ( "errors" "sync" + "github.com/matrix-org/waterfall/src/common" "github.com/pion/webrtc/v3" "github.com/sirupsen/logrus" ) @@ -19,39 +20,30 @@ var ( ErrCantSubscribeToTrack = errors.New("can't subscribe to track") ) -type Peer struct { - id ID +type Peer[ID comparable] struct { logger *logrus.Entry - notify chan<- interface{} peerConnection *webrtc.PeerConnection + sink *common.MessageSink[ID, MessageContent] dataChannelMutex sync.Mutex dataChannel *webrtc.DataChannel } -func NewPeer( - info ID, - conferenceId string, +func NewPeer[ID comparable]( sdpOffer string, - notify chan<- interface{}, -) (*Peer, *webrtc.SessionDescription, error) { - logger := logrus.WithFields(logrus.Fields{ - "user_id": info.UserID, - "device_id": info.DeviceID, - "conf_id": conferenceId, - }) - + sink *common.MessageSink[ID, MessageContent], + logger *logrus.Entry, +) (*Peer[ID], *webrtc.SessionDescription, error) { peerConnection, err := webrtc.NewPeerConnection(webrtc.Configuration{}) if err != nil { logger.WithError(err).Error("failed to create peer connection") return nil, nil, ErrCantCreatePeerConnection } - peer := &Peer{ - id: info, + peer := &Peer[ID]{ logger: logger, - notify: notify, peerConnection: peerConnection, + sink: sink, } peerConnection.OnTrack(peer.onRtpTrackReceived) @@ -99,15 +91,15 @@ func NewPeer( return peer, sdpAnswer, nil } -func (p *Peer) Terminate() { +func (p *Peer[ID]) Terminate() { if err := p.peerConnection.Close(); err != nil { p.logger.WithError(err).Error("failed to close peer connection") } - p.notify <- LeftTheCall{Sender: p.id} + p.sink.Send(LeftTheCall{}) } -func (p *Peer) AddICECandidates(candidates []webrtc.ICECandidateInit) { +func (p *Peer[ID]) AddICECandidates(candidates []webrtc.ICECandidateInit) { for _, candidate := range candidates { if err := p.peerConnection.AddICECandidate(candidate); err != nil { p.logger.WithError(err).Error("failed to add ICE candidate") @@ -115,7 +107,7 @@ func (p *Peer) AddICECandidates(candidates []webrtc.ICECandidateInit) { } } -func (p *Peer) SubscribeToTrack(track *webrtc.TrackLocalStaticRTP) error { +func (p *Peer[ID]) SubscribeToTrack(track *webrtc.TrackLocalStaticRTP) error { _, err := p.peerConnection.AddTrack(track) if err != nil { p.logger.WithError(err).Error("failed to add track") @@ -125,7 +117,7 @@ func (p *Peer) SubscribeToTrack(track *webrtc.TrackLocalStaticRTP) error { return nil } -func (p *Peer) SendOverDataChannel(json string) error { +func (p *Peer[ID]) SendOverDataChannel(json string) error { p.dataChannelMutex.Lock() defer p.dataChannelMutex.Unlock() @@ -146,7 +138,7 @@ func (p *Peer) SendOverDataChannel(json string) error { return nil } -func (p *Peer) NewSDPAnswerReceived(sdpAnswer string) error { +func (p *Peer[ID]) NewSDPAnswerReceived(sdpAnswer string) error { err := p.peerConnection.SetRemoteDescription(webrtc.SessionDescription{ Type: webrtc.SDPTypeAnswer, SDP: sdpAnswer, diff --git a/src/peer/webrtc.go b/src/peer/webrtc.go index 889416c..46c54ef 100644 --- a/src/peer/webrtc.go +++ b/src/peer/webrtc.go @@ -11,7 +11,7 @@ import ( // A callback that is called once we receive first RTP packets from a track, i.e. // we call this function each time a new track is received. -func (p *Peer) onRtpTrackReceived(remoteTrack *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) { +func (p *Peer[ID]) onRtpTrackReceived(remoteTrack *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) { // Send a PLI on an interval so that the publisher is pushing a keyframe every rtcpPLIInterval. // This can be less wasteful by processing incoming RTCP events, then we would emit a NACK/PLI // when a viewer requests it. @@ -40,7 +40,7 @@ func (p *Peer) onRtpTrackReceived(remoteTrack *webrtc.TrackRemote, receiver *web } // Notify others that our track has just been published. - p.notify <- NewTrackPublished{Sender: p.id, Track: localTrack} + p.sink.Send(NewTrackPublished{Track: localTrack}) // Start forwarding the data from the remote track to the local track, // so that everyone who is subscribed to this track will receive the data. @@ -56,31 +56,31 @@ func (p *Peer) onRtpTrackReceived(remoteTrack *webrtc.TrackRemote, receiver *web } else { // finished, no more data, but with error, inform others p.logger.WithError(readErr).Error("failed to read from remote track") } - p.notify <- PublishedTrackFailed{Sender: p.id, Track: localTrack} + p.sink.Send(PublishedTrackFailed{Track: localTrack}) } // ErrClosedPipe means we don't have any subscribers, this is ok if no peers have connected yet. if _, err = localTrack.Write(rtpBuf[:index]); err != nil && !errors.Is(err, io.ErrClosedPipe) { p.logger.WithError(err).Error("failed to write to local track") - p.notify <- PublishedTrackFailed{Sender: p.id, Track: localTrack} + p.sink.Send(PublishedTrackFailed{Track: localTrack}) } } }() } // A callback that is called once we receive an ICE candidate for this peer connection. -func (p *Peer) onICECandidateGathered(candidate *webrtc.ICECandidate) { +func (p *Peer[ID]) onICECandidateGathered(candidate *webrtc.ICECandidate) { if candidate == nil { p.logger.Info("ICE candidate gathering finished") return } p.logger.WithField("candidate", candidate).Debug("ICE candidate gathered") - p.notify <- NewICECandidate{Sender: p.id, Candidate: candidate} + p.sink.Send(NewICECandidate{Candidate: candidate}) } // A callback that is called once we receive an ICE connection state change for this peer connection. -func (p *Peer) onNegotiationNeeded() { +func (p *Peer[ID]) onNegotiationNeeded() { p.logger.Debug("negotiation needed") offer, err := p.peerConnection.CreateOffer(nil) if err != nil { @@ -93,11 +93,11 @@ func (p *Peer) onNegotiationNeeded() { return } - p.notify <- RenegotiationRequired{Sender: p.id, Offer: &offer} + p.sink.Send(RenegotiationRequired{Offer: &offer}) } // A callback that is called once we receive an ICE connection state change for this peer connection. -func (p *Peer) onICEConnectionStateChanged(state webrtc.ICEConnectionState) { +func (p *Peer[ID]) onICEConnectionStateChanged(state webrtc.ICEConnectionState) { p.logger.WithField("state", state).Debug("ICE connection state changed") switch state { @@ -111,35 +111,35 @@ func (p *Peer) onICEConnectionStateChanged(state webrtc.ICEConnectionState) { // TODO: Ask Simon if we should do it here as in the previous implementation of the // `waterfall` or the way I did it in this new implementation. // p.notify <- PeerJoinedTheCall{sender: p.data} - p.notify <- ICEGatheringComplete{Sender: p.id} + p.sink.Send(ICEGatheringComplete{}) } } -func (p *Peer) onICEGatheringStateChanged(state webrtc.ICEGathererState) { +func (p *Peer[ID]) onICEGatheringStateChanged(state webrtc.ICEGathererState) { p.logger.WithField("state", state).Debug("ICE gathering state changed") if state == webrtc.ICEGathererStateComplete { - p.notify <- ICEGatheringComplete{Sender: p.id} + p.sink.Send(ICEGatheringComplete{}) } } -func (p *Peer) onSignalingStateChanged(state webrtc.SignalingState) { +func (p *Peer[ID]) onSignalingStateChanged(state webrtc.SignalingState) { p.logger.WithField("state", state).Debug("signaling state changed") } -func (p *Peer) onConnectionStateChanged(state webrtc.PeerConnectionState) { +func (p *Peer[ID]) onConnectionStateChanged(state webrtc.PeerConnectionState) { p.logger.WithField("state", state).Debug("connection state changed") switch state { case webrtc.PeerConnectionStateFailed, webrtc.PeerConnectionStateDisconnected, webrtc.PeerConnectionStateClosed: - p.notify <- LeftTheCall{Sender: p.id} + p.sink.Send(LeftTheCall{}) case webrtc.PeerConnectionStateConnected: - p.notify <- JoinedTheCall{Sender: p.id} + p.sink.Send(JoinedTheCall{}) } } // A callback that is called once the data channel is ready to be used. -func (p *Peer) onDataChannelReady(dc *webrtc.DataChannel) { +func (p *Peer[ID]) onDataChannelReady(dc *webrtc.DataChannel) { p.dataChannelMutex.Lock() defer p.dataChannelMutex.Unlock() @@ -154,13 +154,13 @@ func (p *Peer) onDataChannelReady(dc *webrtc.DataChannel) { dc.OnOpen(func() { p.logger.Info("data channel opened") - p.notify <- DataChannelAvailable{Sender: p.id} + p.sink.Send(DataChannelAvailable{}) }) dc.OnMessage(func(msg webrtc.DataChannelMessage) { p.logger.WithField("message", msg).Debug("data channel message received") if msg.IsString { - p.notify <- DataChannelMessage{Sender: p.id, Message: string(msg.Data)} + p.sink.Send(DataChannelMessage{Message: string(msg.Data)}) } else { p.logger.Warn("data channel message is not a string, ignoring") } diff --git a/src/router.go b/src/router.go index e7b3899..ec59de9 100644 --- a/src/router.go +++ b/src/router.go @@ -17,8 +17,7 @@ limitations under the License. package main import ( - "github.com/matrix-org/waterfall/src/conference" - "github.com/matrix-org/waterfall/src/peer" + conf "github.com/matrix-org/waterfall/src/conference" "github.com/matrix-org/waterfall/src/signaling" "github.com/sirupsen/logrus" "maunium.net/go/mautrix/event" @@ -29,16 +28,16 @@ type Router struct { // Matrix matrix. matrix *signaling.MatrixClient // All calls currently forwarded by this SFU. - conferences map[string]*conference.Conference + conferences map[string]*conf.Conference // Configuration for the calls. - config conference.Config + config conf.Config } // Creates a new instance of the SFU with the given configuration. -func newRouter(matrix *signaling.MatrixClient, config conference.Config) *Router { +func newRouter(matrix *signaling.MatrixClient, config conf.Config) *Router { return &Router{ matrix: matrix, - conferences: make(map[string]*conference.Conference), + conferences: make(map[string]*conf.Conference), config: config, } } @@ -63,17 +62,17 @@ func (r *Router) handleMatrixEvent(evt *event.Event) { return } - // If there is an invitation sent and the conf does not exist, create one. - if conf := r.conferences[invite.ConfID]; conf == nil { + // If there is an invitation sent and the conference does not exist, create one. + if conference := r.conferences[invite.ConfID]; conference == nil { logger.Infof("creating new conference %s", invite.ConfID) - r.conferences[invite.ConfID] = conference.NewConference( + r.conferences[invite.ConfID] = conf.NewConference( invite.ConfID, r.config, r.matrix.CreateForConference(invite.ConfID), ) } - peerID := peer.ID{ + peerID := conf.ParticipantID{ UserID: evt.Sender, DeviceID: invite.DeviceID, } @@ -95,7 +94,7 @@ func (r *Router) handleMatrixEvent(evt *event.Event) { return } - peerID := peer.ID{ + peerID := conf.ParticipantID{ UserID: evt.Sender, DeviceID: candidates.DeviceID, } @@ -116,7 +115,7 @@ func (r *Router) handleMatrixEvent(evt *event.Event) { return } - peerID := peer.ID{ + peerID := conf.ParticipantID{ UserID: evt.Sender, DeviceID: selectAnswer.DeviceID, } @@ -137,7 +136,7 @@ func (r *Router) handleMatrixEvent(evt *event.Event) { return } - peerID := peer.ID{ + peerID := conf.ParticipantID{ UserID: evt.Sender, DeviceID: hangup.DeviceID, } diff --git a/src/signaling/client.go b/src/signaling/client.go new file mode 100644 index 0000000..a16313f --- /dev/null +++ b/src/signaling/client.go @@ -0,0 +1,67 @@ +package signaling + +import ( + "github.com/sirupsen/logrus" + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/event" +) + +type MatrixClient struct { + client *mautrix.Client +} + +func NewMatrixClient(config Config) *MatrixClient { + client, err := mautrix.NewClient(config.HomeserverURL, config.UserID, config.AccessToken) + if err != nil { + logrus.WithError(err).Fatal("Failed to create client") + } + + whoami, err := client.Whoami() + if err != nil { + logrus.WithError(err).Fatal("Failed to identify SFU user") + } + + if config.UserID != whoami.UserID { + logrus.WithField("user_id", config.UserID).Fatal("Access token is for the wrong user") + } + + logrus.WithField("device_id", whoami.DeviceID).Info("Identified SFU as DeviceID") + client.DeviceID = whoami.DeviceID + + return &MatrixClient{ + client: client, + } +} + +// Starts the Matrix client and connects to the homeserver, +// Returns only when the sync with Matrix fails. +func (m *MatrixClient) RunSync(callback func(*event.Event)) { + syncer, ok := m.client.Syncer.(*mautrix.DefaultSyncer) + if !ok { + logrus.Panic("Syncer is not DefaultSyncer") + } + + syncer.ParseEventContent = true + syncer.OnEvent(func(_ mautrix.EventSource, evt *event.Event) { + // We only care about to-device events. + if evt.Type.Class != event.ToDeviceEventType { + logrus.Warn("ignoring a not to-device event") + return + } + + // We drop the messages if they are not meant for us. + if evt.Content.Raw["dest_session_id"] != LocalSessionID { + logrus.Warn("SessionID does not match our SessionID - ignoring") + return + } + + callback(evt) + }) + + // TODO: We may want to reconnect if `Sync()` fails instead of ending the SFU + // as ending here will essentially drop all conferences which may not necessarily + // be what we want for the existing running conferences. + if err := m.client.Sync(); err != nil { + logrus.WithError(err).Panic("Sync failed") + } +} diff --git a/src/signaling/matrix.go b/src/signaling/matrix.go index 3c56a27..65c39d9 100644 --- a/src/signaling/matrix.go +++ b/src/signaling/matrix.go @@ -17,7 +17,6 @@ limitations under the License. package signaling import ( - "github.com/matrix-org/waterfall/src/peer" "github.com/sirupsen/logrus" "maunium.net/go/mautrix" "maunium.net/go/mautrix/event" @@ -26,66 +25,13 @@ import ( const LocalSessionID = "sfu" -type MatrixClient struct { - client *mautrix.Client -} - -func NewMatrixClient(config Config) *MatrixClient { - client, err := mautrix.NewClient(config.HomeserverURL, config.UserID, config.AccessToken) - if err != nil { - logrus.WithError(err).Fatal("Failed to create client") - } - - whoami, err := client.Whoami() - if err != nil { - logrus.WithError(err).Fatal("Failed to identify SFU user") - } - - if config.UserID != whoami.UserID { - logrus.WithField("user_id", config.UserID).Fatal("Access token is for the wrong user") - } - - logrus.WithField("device_id", whoami.DeviceID).Info("Identified SFU as DeviceID") - client.DeviceID = whoami.DeviceID - - return &MatrixClient{ - client: client, - } -} - -// Starts the Matrix client and connects to the homeserver, -// Returns only when the sync with Matrix fails. -func (m *MatrixClient) RunSync(callback func(*event.Event)) { - syncer, ok := m.client.Syncer.(*mautrix.DefaultSyncer) - if !ok { - logrus.Panic("Syncer is not DefaultSyncer") - } - - syncer.ParseEventContent = true - syncer.OnEvent(func(_ mautrix.EventSource, evt *event.Event) { - // We only care about to-device events. - if evt.Type.Class != event.ToDeviceEventType { - logrus.Warn("ignoring a not to-device event") - return - } - - // We drop the messages if they are not meant for us. - if evt.Content.Raw["dest_session_id"] != LocalSessionID { - logrus.Warn("SessionID does not match our SessionID - ignoring") - return - } - - callback(evt) - }) - - // TODO: We may want to reconnect if `Sync()` fails instead of ending the SFU - // as ending here will essentially drop all conferences which may not necessarily - // be what we want for the existing running conferences. - if err := m.client.Sync(); err != nil { - logrus.WithError(err).Panic("Sync failed") - } +// Matrix client scoped for a particular conference. +type MatrixForConference struct { + client *mautrix.Client + conferenceID string } +// Create a new Matrix client that abstarcts outgoing Matrix messages from a given conference. func (m *MatrixClient) CreateForConference(conferenceID string) *MatrixForConference { return &MatrixForConference{ client: m.client, @@ -93,11 +39,14 @@ func (m *MatrixClient) CreateForConference(conferenceID string) *MatrixForConfer } } +// Defines the data that identifies a receiver of Matrix's to-device message. type MatrixRecipient struct { - ID peer.ID + UserID id.UserID + DeviceID id.DeviceID RemoteSessionID id.SessionID } +// Interface that abstracts sending Send-to-device messages for the conference. type MatrixSignaling interface { SendSDPAnswer(recipient MatrixRecipient, streamMetadata event.CallSDPStreamMetadata, sdp string) SendICECandidates(recipient MatrixRecipient, candidates []event.CallCandidate) @@ -105,11 +54,6 @@ type MatrixSignaling interface { SendHangup(recipient MatrixRecipient, reason event.CallHangupReason) } -type MatrixForConference struct { - client *mautrix.Client - conferenceID string -} - func (m *MatrixForConference) SendSDPAnswer( recipient MatrixRecipient, streamMetadata event.CallSDPStreamMetadata, @@ -117,7 +61,7 @@ func (m *MatrixForConference) SendSDPAnswer( ) { eventContent := &event.Content{ Parsed: event.CallAnswerEventContent{ - BaseCallEventContent: m.createBaseEventContent(recipient.ID.DeviceID, recipient.RemoteSessionID), + BaseCallEventContent: m.createBaseEventContent(recipient.DeviceID, recipient.RemoteSessionID), Answer: event.CallData{ Type: "answer", SDP: sdp, @@ -126,40 +70,40 @@ func (m *MatrixForConference) SendSDPAnswer( }, } - m.sendToDevice(recipient.ID, event.CallAnswer, eventContent) + m.sendToDevice(recipient, event.CallAnswer, eventContent) } func (m *MatrixForConference) SendICECandidates(recipient MatrixRecipient, candidates []event.CallCandidate) { eventContent := &event.Content{ Parsed: event.CallCandidatesEventContent{ - BaseCallEventContent: m.createBaseEventContent(recipient.ID.DeviceID, recipient.RemoteSessionID), + BaseCallEventContent: m.createBaseEventContent(recipient.DeviceID, recipient.RemoteSessionID), Candidates: candidates, }, } - m.sendToDevice(recipient.ID, event.CallCandidates, eventContent) + m.sendToDevice(recipient, event.CallCandidates, eventContent) } func (m *MatrixForConference) SendCandidatesGatheringFinished(recipient MatrixRecipient) { eventContent := &event.Content{ Parsed: event.CallCandidatesEventContent{ - BaseCallEventContent: m.createBaseEventContent(recipient.ID.DeviceID, recipient.RemoteSessionID), + BaseCallEventContent: m.createBaseEventContent(recipient.DeviceID, recipient.RemoteSessionID), Candidates: []event.CallCandidate{{Candidate: ""}}, }, } - m.sendToDevice(recipient.ID, event.CallCandidates, eventContent) + m.sendToDevice(recipient, event.CallCandidates, eventContent) } func (m *MatrixForConference) SendHangup(recipient MatrixRecipient, reason event.CallHangupReason) { eventContent := &event.Content{ Parsed: event.CallHangupEventContent{ - BaseCallEventContent: m.createBaseEventContent(recipient.ID.DeviceID, recipient.RemoteSessionID), + BaseCallEventContent: m.createBaseEventContent(recipient.DeviceID, recipient.RemoteSessionID), Reason: reason, }, } - m.sendToDevice(recipient.ID, event.CallHangup, eventContent) + m.sendToDevice(recipient, event.CallHangup, eventContent) } func (m *MatrixForConference) createBaseEventContent( @@ -178,17 +122,17 @@ func (m *MatrixForConference) createBaseEventContent( } // Sends a to-device event to the given user. -func (m *MatrixForConference) sendToDevice(participantID peer.ID, eventType event.Type, eventContent *event.Content) { +func (m *MatrixForConference) sendToDevice(user MatrixRecipient, eventType event.Type, eventContent *event.Content) { // TODO: Don't create logger again and again, it might be a bit expensive. logger := logrus.WithFields(logrus.Fields{ - "user_id": participantID.UserID, - "device_id": participantID.DeviceID, + "user_id": user.UserID, + "device_id": user.DeviceID, }) sendRequest := &mautrix.ReqSendToDevice{ Messages: map[id.UserID]map[id.DeviceID]*event.Content{ - participantID.UserID: { - participantID.DeviceID: eventContent, + user.UserID: { + user.DeviceID: eventContent, }, }, } From 532773f2952b4e3ad911d5e94867a9d38a28b53c Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Thu, 24 Nov 2022 17:48:45 +0100 Subject: [PATCH 07/62] general: rename `src` to `pkg` This seems to be more idiomatic in Go projects. --- Dockerfile | 4 ++-- {src => pkg}/common/message_sink.go | 0 {src => pkg}/conference/conference.go | 6 +++--- {src => pkg}/conference/config.go | 0 {src => pkg}/conference/participant.go | 4 ++-- {src => pkg}/conference/processor.go | 4 ++-- {src => pkg}/config/config.go | 4 ++-- {src => pkg}/logger.go | 0 {src => pkg}/main.go | 4 ++-- {src => pkg}/peer/messages.go | 0 {src => pkg}/peer/peer.go | 2 +- {src => pkg}/peer/webrtc.go | 0 {src => pkg}/profile.go | 0 {src => pkg}/publisher.go | 0 {src => pkg}/router.go | 4 ++-- {src => pkg}/signaling/client.go | 0 {src => pkg}/signaling/config.go | 0 {src => pkg}/signaling/matrix.go | 0 {src => pkg}/subscriber.go | 0 scripts/build.sh | 2 +- scripts/profile.sh | 2 +- scripts/run.sh | 2 +- 22 files changed, 19 insertions(+), 19 deletions(-) rename {src => pkg}/common/message_sink.go (100%) rename {src => pkg}/conference/conference.go (96%) rename {src => pkg}/conference/config.go (100%) rename {src => pkg}/conference/participant.go (96%) rename {src => pkg}/conference/processor.go (98%) rename {src => pkg}/config/config.go (95%) rename {src => pkg}/logger.go (100%) rename {src => pkg}/main.go (96%) rename {src => pkg}/peer/messages.go (100%) rename {src => pkg}/peer/peer.go (99%) rename {src => pkg}/peer/webrtc.go (100%) rename {src => pkg}/profile.go (100%) rename {src => pkg}/publisher.go (100%) rename {src => pkg}/router.go (97%) rename {src => pkg}/signaling/client.go (100%) rename {src => pkg}/signaling/config.go (100%) rename {src => pkg}/signaling/matrix.go (100%) rename {src => pkg}/subscriber.go (100%) diff --git a/Dockerfile b/Dockerfile index b3bc9f6..b58d0f2 100644 --- a/Dockerfile +++ b/Dockerfile @@ -13,9 +13,9 @@ COPY go.sum ./ # source code do not invalidate our downloaded layer. RUN go mod download -COPY ./src ./src +COPY ./pkg ./pkg -RUN go build -o /waterfall ./src +RUN go build -o /waterfall ./pkg ## diff --git a/src/common/message_sink.go b/pkg/common/message_sink.go similarity index 100% rename from src/common/message_sink.go rename to pkg/common/message_sink.go diff --git a/src/conference/conference.go b/pkg/conference/conference.go similarity index 96% rename from src/conference/conference.go rename to pkg/conference/conference.go index 56ef684..901906e 100644 --- a/src/conference/conference.go +++ b/pkg/conference/conference.go @@ -17,9 +17,9 @@ limitations under the License. package conference import ( - "github.com/matrix-org/waterfall/src/common" - "github.com/matrix-org/waterfall/src/peer" - "github.com/matrix-org/waterfall/src/signaling" + "github.com/matrix-org/waterfall/pkg/common" + "github.com/matrix-org/waterfall/pkg/peer" + "github.com/matrix-org/waterfall/pkg/signaling" "github.com/pion/webrtc/v3" "github.com/sirupsen/logrus" "maunium.net/go/mautrix/event" diff --git a/src/conference/config.go b/pkg/conference/config.go similarity index 100% rename from src/conference/config.go rename to pkg/conference/config.go diff --git a/src/conference/participant.go b/pkg/conference/participant.go similarity index 96% rename from src/conference/participant.go rename to pkg/conference/participant.go index 08b5f88..835fe85 100644 --- a/src/conference/participant.go +++ b/pkg/conference/participant.go @@ -4,8 +4,8 @@ import ( "encoding/json" "errors" - "github.com/matrix-org/waterfall/src/peer" - "github.com/matrix-org/waterfall/src/signaling" + "github.com/matrix-org/waterfall/pkg/peer" + "github.com/matrix-org/waterfall/pkg/signaling" "github.com/pion/webrtc/v3" "github.com/sirupsen/logrus" "maunium.net/go/mautrix/event" diff --git a/src/conference/processor.go b/pkg/conference/processor.go similarity index 98% rename from src/conference/processor.go rename to pkg/conference/processor.go index c401079..b3f0a29 100644 --- a/src/conference/processor.go +++ b/pkg/conference/processor.go @@ -4,8 +4,8 @@ import ( "encoding/json" "errors" - "github.com/matrix-org/waterfall/src/common" - "github.com/matrix-org/waterfall/src/peer" + "github.com/matrix-org/waterfall/pkg/common" + "github.com/matrix-org/waterfall/pkg/peer" "maunium.net/go/mautrix/event" ) diff --git a/src/config/config.go b/pkg/config/config.go similarity index 95% rename from src/config/config.go rename to pkg/config/config.go index d9adc83..d908c34 100644 --- a/src/config/config.go +++ b/pkg/config/config.go @@ -5,8 +5,8 @@ import ( "fmt" "os" - "github.com/matrix-org/waterfall/src/conference" - "github.com/matrix-org/waterfall/src/signaling" + "github.com/matrix-org/waterfall/pkg/conference" + "github.com/matrix-org/waterfall/pkg/signaling" "github.com/sirupsen/logrus" "gopkg.in/yaml.v3" ) diff --git a/src/logger.go b/pkg/logger.go similarity index 100% rename from src/logger.go rename to pkg/logger.go diff --git a/src/main.go b/pkg/main.go similarity index 96% rename from src/main.go rename to pkg/main.go index cfffc62..855ba1d 100644 --- a/src/main.go +++ b/pkg/main.go @@ -22,8 +22,8 @@ import ( "os/signal" "syscall" - "github.com/matrix-org/waterfall/src/config" - "github.com/matrix-org/waterfall/src/signaling" + "github.com/matrix-org/waterfall/pkg/config" + "github.com/matrix-org/waterfall/pkg/signaling" "github.com/sirupsen/logrus" ) diff --git a/src/peer/messages.go b/pkg/peer/messages.go similarity index 100% rename from src/peer/messages.go rename to pkg/peer/messages.go diff --git a/src/peer/peer.go b/pkg/peer/peer.go similarity index 99% rename from src/peer/peer.go rename to pkg/peer/peer.go index 917682b..960b405 100644 --- a/src/peer/peer.go +++ b/pkg/peer/peer.go @@ -4,7 +4,7 @@ import ( "errors" "sync" - "github.com/matrix-org/waterfall/src/common" + "github.com/matrix-org/waterfall/pkg/common" "github.com/pion/webrtc/v3" "github.com/sirupsen/logrus" ) diff --git a/src/peer/webrtc.go b/pkg/peer/webrtc.go similarity index 100% rename from src/peer/webrtc.go rename to pkg/peer/webrtc.go diff --git a/src/profile.go b/pkg/profile.go similarity index 100% rename from src/profile.go rename to pkg/profile.go diff --git a/src/publisher.go b/pkg/publisher.go similarity index 100% rename from src/publisher.go rename to pkg/publisher.go diff --git a/src/router.go b/pkg/router.go similarity index 97% rename from src/router.go rename to pkg/router.go index ec59de9..11bf567 100644 --- a/src/router.go +++ b/pkg/router.go @@ -17,8 +17,8 @@ limitations under the License. package main import ( - conf "github.com/matrix-org/waterfall/src/conference" - "github.com/matrix-org/waterfall/src/signaling" + conf "github.com/matrix-org/waterfall/pkg/conference" + "github.com/matrix-org/waterfall/pkg/signaling" "github.com/sirupsen/logrus" "maunium.net/go/mautrix/event" ) diff --git a/src/signaling/client.go b/pkg/signaling/client.go similarity index 100% rename from src/signaling/client.go rename to pkg/signaling/client.go diff --git a/src/signaling/config.go b/pkg/signaling/config.go similarity index 100% rename from src/signaling/config.go rename to pkg/signaling/config.go diff --git a/src/signaling/matrix.go b/pkg/signaling/matrix.go similarity index 100% rename from src/signaling/matrix.go rename to pkg/signaling/matrix.go diff --git a/src/subscriber.go b/pkg/subscriber.go similarity index 100% rename from src/subscriber.go rename to pkg/subscriber.go diff --git a/scripts/build.sh b/scripts/build.sh index 9702abb..56595bf 100755 --- a/scripts/build.sh +++ b/scripts/build.sh @@ -1,3 +1,3 @@ #!/usr/bin/env bash -go build -o dist/waterfall ./src +go build -o dist/waterfall ./pkg diff --git a/scripts/profile.sh b/scripts/profile.sh index 071103e..0042a33 100755 --- a/scripts/profile.sh +++ b/scripts/profile.sh @@ -1,3 +1,3 @@ #!/usr/bin/env bash -go run ./src/*.go --cpuProfile cpuProfile.pprof --memProfile memProfile.pprof --logTime +go run ./pkg/*.go --cpuProfile cpuProfile.pprof --memProfile memProfile.pprof --logTime diff --git a/scripts/run.sh b/scripts/run.sh index 50c2f32..16e9b1f 100755 --- a/scripts/run.sh +++ b/scripts/run.sh @@ -1,3 +1,3 @@ #!/usr/bin/env bash -go run ./src --logTime +go run ./pkg --logTime From 4d7970cee16c22a373a1fc43322dd90c158c20dc Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Thu, 24 Nov 2022 21:12:30 +0100 Subject: [PATCH 08/62] conference: define sane logic for new participants New participants are now properly handled depending on whether or not they were in the session. Also, the conference and peer get documented. --- pkg/common/message_sink.go | 25 +++++- pkg/conference/conference.go | 149 +++++++++++++++++++++++++--------- pkg/conference/participant.go | 50 +----------- pkg/conference/processor.go | 10 +-- pkg/peer/messages.go | 2 + pkg/peer/peer.go | 106 ++++++++++++++---------- 6 files changed, 208 insertions(+), 134 deletions(-) diff --git a/pkg/common/message_sink.go b/pkg/common/message_sink.go index 3351492..1c7604f 100644 --- a/pkg/common/message_sink.go +++ b/pkg/common/message_sink.go @@ -1,5 +1,10 @@ package common +import ( + "errors" + "sync/atomic" +) + // MessageSink is a helper struct that allows to send messages to a message sink. // The MessageSink abstracts the message sink which has a certain sender, so that // the sender does not have to be specified every time a message is sent. @@ -10,6 +15,10 @@ type MessageSink[SenderType comparable, MessageType any] struct { sender SenderType // The message sink to which the messages are sent. messageSink chan<- Message[SenderType, MessageType] + // Atomic variable that indicates whether the message sink is sealed. + // This is used to prevent sending messages to a sealed message sink. + // The variable is atomic because it may be accessed from multiple goroutines. + sealed atomic.Bool } // Creates a new MessageSink. The function is generic allowing us to use it for multiple use cases. @@ -21,11 +30,25 @@ func NewMessageSink[S comparable, M any](sender S, messageSink chan<- Message[S, } // Sends a message to the message sink. -func (s *MessageSink[S, M]) Send(message M) { +func (s *MessageSink[S, M]) Send(message M) error { + if s.sealed.Load() { + return errors.New("The channel is sealed, you can't send any messages over it") + } + s.messageSink <- Message[S, M]{ Sender: s.sender, Content: message, } + + return nil +} + +// Seals the channel, which means that no messages could be sent via this channel. +// Any attempt to send a message would result in an error. This is similar to closing the +// channel except that we don't close the underlying channel (since there might be other +// senders that may want to use it). +func (s *MessageSink[S, M]) Seal() { + s.sealed.Store(true) } // Messages that are sent from the peer to the conference in order to communicate with other peers. diff --git a/pkg/conference/conference.go b/pkg/conference/conference.go index 901906e..9132964 100644 --- a/pkg/conference/conference.go +++ b/pkg/conference/conference.go @@ -25,6 +25,7 @@ import ( "maunium.net/go/mautrix/event" ) +// A single conference. Call and conference mean the same in context of Matrix. type Conference struct { id string config Config @@ -51,53 +52,66 @@ func NewConference(confID string, config Config, signaling signaling.MatrixSigna // New participant tries to join the conference. func (c *Conference) OnNewParticipant(participantID ParticipantID, inviteEvent *event.CallInviteEventContent) { + logger := logrus.WithFields(logrus.Fields{ + "user_id": participantID.UserID, + "device_id": participantID.DeviceID, + "conf_id": c.id, + }) + // As per MSC3401, when the `session_id` field changes from an incoming `m.call.member` event, // any existing calls from this device in this call should be terminated. - // TODO: Implement this. - for id, participant := range c.participants { - if id.DeviceID == inviteEvent.DeviceID { - if participant.remoteSessionID == inviteEvent.SenderSessionID { - c.logger.WithFields(logrus.Fields{ - "device_id": inviteEvent.DeviceID, - "session_id": inviteEvent.SenderSessionID, - }).Errorf("Found existing participant with equal DeviceID and SessionID") - return - } else { - participant.peer.Terminate() - } + participant := c.getParticipant(participantID, nil) + if participant != nil { + if participant.remoteSessionID == inviteEvent.SenderSessionID { + c.logger.Errorf("Found existing participant with equal DeviceID and SessionID") + } else { + c.removeParticipant(participantID) + participant = nil } } - var ( - participantlogger = logrus.WithFields(logrus.Fields{ - "user_id": participantID.UserID, - "device_id": participantID.DeviceID, - "conf_id": c.id, - }) - messageSink = common.NewMessageSink(participantID, c.peerEvents) - ) + // If participant exists still exists, then it means that the client does not behave properly. + // In this case we treat this new invitation as a new SDP offer. Otherwise, we create a new one. + sdpAnswer, err := func() (*webrtc.SessionDescription, error) { + if participant == nil { + messageSink := common.NewMessageSink(participantID, c.peerEvents) - peer, sdpOffer, err := peer.NewPeer(inviteEvent.Offer.SDP, messageSink, participantlogger) + peer, answer, err := peer.NewPeer(inviteEvent.Offer.SDP, messageSink, logger) + if err != nil { + return nil, err + } + + participant = &Participant{ + id: participantID, + peer: peer, + remoteSessionID: inviteEvent.SenderSessionID, + streamMetadata: inviteEvent.SDPStreamMetadata, + publishedTracks: make(map[event.SFUTrackDescription]*webrtc.TrackLocalStaticRTP), + } + + c.participants[participantID] = participant + return answer, nil + } else { + answer, err := participant.peer.ProcessSDPOffer(inviteEvent.Offer.SDP) + if err != nil { + return nil, err + } + return answer, nil + } + }() if err != nil { - c.logger.WithError(err).Errorf("Failed to create new peer") + logger.WithError(err).Errorf("Failed to process SDP offer") return } - participant := &Participant{ - id: participantID, - peer: peer, - remoteSessionID: inviteEvent.SenderSessionID, - streamMetadata: inviteEvent.SDPStreamMetadata, - publishedTracks: make(map[event.SFUTrackDescription]*webrtc.TrackLocalStaticRTP), - } - - c.participants[participantID] = participant - + // Send the answer back to the remote peer. recipient := participant.asMatrixRecipient() - streamMetadata := c.getStreamsMetadata(participantID) - c.signaling.SendSDPAnswer(recipient, streamMetadata, sdpOffer.SDP) + streamMetadata := c.getAvailableStreamsFor(participantID) + c.signaling.SendSDPAnswer(recipient, streamMetadata, sdpAnswer.SDP) } +// Process new ICE candidates received from Matrix signaling (from the remote peer) and forward them to +// our internal peer connection. func (c *Conference) OnCandidates(participantID ParticipantID, ev *event.CallCandidatesEventContent) { if participant := c.getParticipant(participantID, nil); participant != nil { // Convert the candidates to the WebRTC format. @@ -111,10 +125,12 @@ func (c *Conference) OnCandidates(participantID ParticipantID, ev *event.CallCan } } - participant.peer.AddICECandidates(candidates) + participant.peer.ProcessNewRemoteCandidates(candidates) } } +// Process an acknowledgement from the remote peer that the SDP answer has been received +// and that the call can now proceed. func (c *Conference) OnSelectAnswer(participantID ParticipantID, ev *event.CallSelectAnswerEventContent) { if participant := c.getParticipant(participantID, nil); participant != nil { if ev.SelectedPartyID != participantID.DeviceID.String() { @@ -126,8 +142,67 @@ func (c *Conference) OnSelectAnswer(participantID ParticipantID, ev *event.CallS } } +// Process a message from the remote peer telling that it wants to hang up the call. func (c *Conference) OnHangup(participantID ParticipantID, ev *event.CallHangupEventContent) { - if participant := c.getParticipant(participantID, nil); participant != nil { - participant.peer.Terminate() + c.removeParticipant(participantID) +} + +func (c *Conference) getParticipant(participantID ParticipantID, optionalErrorMessage error) *Participant { + participant, ok := c.participants[participantID] + if !ok { + logEntry := c.logger.WithFields(logrus.Fields{ + "user_id": participantID.UserID, + "device_id": participantID.DeviceID, + }) + + if optionalErrorMessage != nil { + logEntry.WithError(optionalErrorMessage) + } else { + logEntry.Error("Participant not found") + } + + return nil + } + + return participant +} + +// Helper to terminate and remove a participant from the conference. +func (c *Conference) removeParticipant(participantID ParticipantID) { + participant := c.getParticipant(participantID, nil) + if participant == nil { + return + } + + participant.peer.Terminate() + delete(c.participants, participantID) +} + +// Helper to get the list of available streams for a given participant, i.e. the list of streams +// that a given participant **can subscribe to**. +func (c *Conference) getAvailableStreamsFor(forParticipant ParticipantID) event.CallSDPStreamMetadata { + streamsMetadata := make(event.CallSDPStreamMetadata) + for id, participant := range c.participants { + if forParticipant != id { + for streamID, metadata := range participant.streamMetadata { + streamsMetadata[streamID] = metadata + } + } + } + + return streamsMetadata +} + +// Helper that returns the list of streams inside this conference that match the given stream IDs. +func (c *Conference) getTracks(identifiers []event.SFUTrackDescription) []*webrtc.TrackLocalStaticRTP { + tracks := make([]*webrtc.TrackLocalStaticRTP, len(identifiers)) + for _, participant := range c.participants { + // Check if this participant has any of the tracks that we're looking for. + for _, identifier := range identifiers { + if track, ok := participant.publishedTracks[identifier]; ok { + tracks = append(tracks, track) + } + } } + return tracks } diff --git a/pkg/conference/participant.go b/pkg/conference/participant.go index 835fe85..f7565fb 100644 --- a/pkg/conference/participant.go +++ b/pkg/conference/participant.go @@ -7,18 +7,20 @@ import ( "github.com/matrix-org/waterfall/pkg/peer" "github.com/matrix-org/waterfall/pkg/signaling" "github.com/pion/webrtc/v3" - "github.com/sirupsen/logrus" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) var ErrInvalidSFUMessage = errors.New("invalid SFU message") +// Things that we assume as identifiers for the participants in the call. +// There could be no 2 participants in the room with identical IDs. type ParticipantID struct { UserID id.UserID DeviceID id.DeviceID } +// Participant represents a participant in the conference. type Participant struct { id ParticipantID peer *peer.Peer[ParticipantID] @@ -49,49 +51,3 @@ func (p *Participant) sendDataChannelMessage(toSend event.SFUMessage) error { return nil } - -func (c *Conference) getParticipant(participantID ParticipantID, optionalErrorMessage error) *Participant { - participant, ok := c.participants[participantID] - if !ok { - logEntry := c.logger.WithFields(logrus.Fields{ - "user_id": participantID.UserID, - "device_id": participantID.DeviceID, - }) - - if optionalErrorMessage != nil { - logEntry.WithError(optionalErrorMessage) - } else { - logEntry.Error("Participant not found") - } - - return nil - } - - return participant -} - -func (c *Conference) getStreamsMetadata(forParticipant ParticipantID) event.CallSDPStreamMetadata { - streamsMetadata := make(event.CallSDPStreamMetadata) - for id, participant := range c.participants { - if forParticipant != id { - for streamID, metadata := range participant.streamMetadata { - streamsMetadata[streamID] = metadata - } - } - } - - return streamsMetadata -} - -func (c *Conference) getTracks(identifiers []event.SFUTrackDescription) []*webrtc.TrackLocalStaticRTP { - tracks := make([]*webrtc.TrackLocalStaticRTP, len(identifiers)) - for _, participant := range c.participants { - // Check if this participant has any of the tracks that we're looking for. - for _, identifier := range identifiers { - if track, ok := participant.publishedTracks[identifier]; ok { - tracks = append(tracks, track) - } - } - } - return tracks -} diff --git a/pkg/conference/processor.go b/pkg/conference/processor.go index b3f0a29..89db48d 100644 --- a/pkg/conference/processor.go +++ b/pkg/conference/processor.go @@ -71,7 +71,7 @@ func (c *Conference) processPeerMessage(message common.Message[ParticipantID, pe toSend := event.SFUMessage{ Op: event.SFUOperationOffer, SDP: msg.Offer.SDP, - Metadata: c.getStreamsMetadata(participant.id), + Metadata: c.getAvailableStreamsFor(participant.id), } participant.sendDataChannelMessage(toSend) @@ -88,7 +88,7 @@ func (c *Conference) processPeerMessage(message common.Message[ParticipantID, pe case peer.DataChannelAvailable: toSend := event.SFUMessage{ Op: event.SFUOperationMetadata, - Metadata: c.getStreamsMetadata(participant.id), + Metadata: c.getAvailableStreamsFor(participant.id), } if err := participant.sendDataChannelMessage(toSend); err != nil { @@ -107,14 +107,14 @@ func (c *Conference) handleDataChannelMessage(participant *Participant, sfuMessa case event.SFUOperationSelect: // Get the tracks that correspond to the tracks that the participant wants to receive. for _, track := range c.getTracks(sfuMessage.Start) { - if err := participant.peer.SubscribeToTrack(track); err != nil { + if err := participant.peer.SubscribeTo(track); err != nil { c.logger.Errorf("Failed to subscribe to track: %v", err) return } } case event.SFUOperationAnswer: - if err := participant.peer.NewSDPAnswerReceived(sfuMessage.SDP); err != nil { + if err := participant.peer.ProcessSDPAnswer(sfuMessage.SDP); err != nil { c.logger.Errorf("Failed to set SDP answer: %v", err) return } @@ -137,7 +137,7 @@ func (c *Conference) handleDataChannelMessage(participant *Participant, sfuMessa toSend := event.SFUMessage{ Op: event.SFUOperationMetadata, - Metadata: c.getStreamsMetadata(id), + Metadata: c.getAvailableStreamsFor(id), } if err := participant.sendDataChannelMessage(toSend); err != nil { diff --git a/pkg/peer/messages.go b/pkg/peer/messages.go index 51ef1d9..7d8adf8 100644 --- a/pkg/peer/messages.go +++ b/pkg/peer/messages.go @@ -4,6 +4,8 @@ import ( "github.com/pion/webrtc/v3" ) +// Due to the limitation of Go, we're using the `interface{}` to be able to use switch the actual +// type of the message on runtime. The underlying types do not necessary need to be structures. type MessageContent = interface{} type JoinedTheCall struct{} diff --git a/pkg/peer/peer.go b/pkg/peer/peer.go index 960b405..8177d5e 100644 --- a/pkg/peer/peer.go +++ b/pkg/peer/peer.go @@ -20,6 +20,10 @@ var ( ErrCantSubscribeToTrack = errors.New("can't subscribe to track") ) +// A wrapped representation of the peer connection (single peer in the call). +// The peer gets information about the things happening outside via public methods +// and informs the outside world about the things happening inside the peer by posting +// the messages to the channel. type Peer[ID comparable] struct { logger *logrus.Entry peerConnection *webrtc.PeerConnection @@ -29,6 +33,7 @@ type Peer[ID comparable] struct { dataChannel *webrtc.DataChannel } +// Instantiates a new peer with a given SDP offer and returns a peer and the SDP answer if everything is ok. func NewPeer[ID comparable]( sdpOffer string, sink *common.MessageSink[ID, MessageContent], @@ -55,59 +60,27 @@ func NewPeer[ID comparable]( peerConnection.OnConnectionStateChange(peer.onConnectionStateChanged) peerConnection.OnSignalingStateChange(peer.onSignalingStateChanged) - err = peerConnection.SetRemoteDescription(webrtc.SessionDescription{ - Type: webrtc.SDPTypeOffer, - SDP: sdpOffer, - }) - if err != nil { - logger.WithError(err).Error("failed to set remote description") - peerConnection.Close() - return nil, nil, ErrCantSetRemoteDecsription - } - - answer, err := peerConnection.CreateAnswer(nil) - if err != nil { - logger.WithError(err).Error("failed to create answer") - peerConnection.Close() - return nil, nil, ErrCantCreateAnswer - } - - if err := peerConnection.SetLocalDescription(answer); err != nil { - logger.WithError(err).Error("failed to set local description") - peerConnection.Close() - return nil, nil, ErrCantSetLocalDescription - } - - // TODO: Do we really need to call `webrtc.GatheringCompletePromise` - // as in the previous version of the `waterfall` here? - - sdpAnswer := peerConnection.LocalDescription() - if sdpAnswer == nil { - logger.WithError(err).Error("could not generate a local description") - peerConnection.Close() - return nil, nil, ErrCantCreateLocalDescription + if sdpAnswer, err := peer.ProcessSDPOffer(sdpOffer); err != nil { + return nil, nil, err + } else { + return peer, sdpAnswer, nil } - - return peer, sdpAnswer, nil } +// Closes peer connection. From this moment on, no new messages will be sent from the peer. func (p *Peer[ID]) Terminate() { if err := p.peerConnection.Close(); err != nil { p.logger.WithError(err).Error("failed to close peer connection") } - p.sink.Send(LeftTheCall{}) + // We want to seal the channel since the sender is not interested in us anymore. + // We may want to remove this logic if/once we want to receive messages (confirmation of close or whatever) + // from the peer that is considered closed. + p.sink.Seal() } -func (p *Peer[ID]) AddICECandidates(candidates []webrtc.ICECandidateInit) { - for _, candidate := range candidates { - if err := p.peerConnection.AddICECandidate(candidate); err != nil { - p.logger.WithError(err).Error("failed to add ICE candidate") - } - } -} - -func (p *Peer[ID]) SubscribeToTrack(track *webrtc.TrackLocalStaticRTP) error { +// Add given track to our peer connection, so that it can be sent to the remote peer. +func (p *Peer[ID]) SubscribeTo(track *webrtc.TrackLocalStaticRTP) error { _, err := p.peerConnection.AddTrack(track) if err != nil { p.logger.WithError(err).Error("failed to add track") @@ -117,6 +90,7 @@ func (p *Peer[ID]) SubscribeToTrack(track *webrtc.TrackLocalStaticRTP) error { return nil } +// Tries to send the given message to the remote counterpart of our peer. func (p *Peer[ID]) SendOverDataChannel(json string) error { p.dataChannelMutex.Lock() defer p.dataChannelMutex.Unlock() @@ -138,7 +112,17 @@ func (p *Peer[ID]) SendOverDataChannel(json string) error { return nil } -func (p *Peer[ID]) NewSDPAnswerReceived(sdpAnswer string) error { +// Processes the remote ICE candidates. +func (p *Peer[ID]) ProcessNewRemoteCandidates(candidates []webrtc.ICECandidateInit) { + for _, candidate := range candidates { + if err := p.peerConnection.AddICECandidate(candidate); err != nil { + p.logger.WithError(err).Error("failed to add ICE candidate") + } + } +} + +// Processes the SDP answer received from the remote peer. +func (p *Peer[ID]) ProcessSDPAnswer(sdpAnswer string) error { err := p.peerConnection.SetRemoteDescription(webrtc.SessionDescription{ Type: webrtc.SDPTypeAnswer, SDP: sdpAnswer, @@ -150,3 +134,37 @@ func (p *Peer[ID]) NewSDPAnswerReceived(sdpAnswer string) error { return nil } + +// Applies the sdp offer received from the remote peer and generates an SDP answer. +func (p *Peer[ID]) ProcessSDPOffer(sdpOffer string) (*webrtc.SessionDescription, error) { + err := p.peerConnection.SetRemoteDescription(webrtc.SessionDescription{ + Type: webrtc.SDPTypeOffer, + SDP: sdpOffer, + }) + if err != nil { + p.logger.WithError(err).Error("failed to set remote description") + return nil, ErrCantSetRemoteDecsription + } + + answer, err := p.peerConnection.CreateAnswer(nil) + if err != nil { + p.logger.WithError(err).Error("failed to create answer") + return nil, ErrCantCreateAnswer + } + + if err := p.peerConnection.SetLocalDescription(answer); err != nil { + p.logger.WithError(err).Error("failed to set local description") + return nil, ErrCantSetLocalDescription + } + + // TODO: Do we really need to call `webrtc.GatheringCompletePromise` + // as in the previous version of the `waterfall` here? + + sdpAnswer := p.peerConnection.LocalDescription() + if sdpAnswer == nil { + p.logger.WithError(err).Error("could not generate a local description") + return nil, ErrCantCreateLocalDescription + } + + return sdpAnswer, nil +} From 83a49b5dccd9e1ac08cfb3af5a53d105a24fc5a9 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Thu, 24 Nov 2022 21:59:52 +0100 Subject: [PATCH 09/62] conference: handle publish messages from peers --- pkg/conference/conference.go | 5 ++-- pkg/conference/participant.go | 14 ++++------- pkg/conference/processor.go | 47 ++++++++++++++++++----------------- pkg/peer/peer.go | 3 +-- pkg/peer/webrtc.go | 8 ++---- 5 files changed, 35 insertions(+), 42 deletions(-) diff --git a/pkg/conference/conference.go b/pkg/conference/conference.go index 9132964..880f5f6 100644 --- a/pkg/conference/conference.go +++ b/pkg/conference/conference.go @@ -52,10 +52,9 @@ func NewConference(confID string, config Config, signaling signaling.MatrixSigna // New participant tries to join the conference. func (c *Conference) OnNewParticipant(participantID ParticipantID, inviteEvent *event.CallInviteEventContent) { - logger := logrus.WithFields(logrus.Fields{ + logger := c.logger.WithFields(logrus.Fields{ "user_id": participantID.UserID, "device_id": participantID.DeviceID, - "conf_id": c.id, }) // As per MSC3401, when the `session_id` field changes from an incoming `m.call.member` event, @@ -84,6 +83,7 @@ func (c *Conference) OnNewParticipant(participantID ParticipantID, inviteEvent * participant = &Participant{ id: participantID, peer: peer, + logger: logger, remoteSessionID: inviteEvent.SenderSessionID, streamMetadata: inviteEvent.SDPStreamMetadata, publishedTracks: make(map[event.SFUTrackDescription]*webrtc.TrackLocalStaticRTP), @@ -136,6 +136,7 @@ func (c *Conference) OnSelectAnswer(participantID ParticipantID, ev *event.CallS if ev.SelectedPartyID != participantID.DeviceID.String() { c.logger.WithFields(logrus.Fields{ "device_id": ev.SelectedPartyID, + "user_id": participantID, }).Errorf("Call was answered on a different device, kicking this peer") participant.peer.Terminate() } diff --git a/pkg/conference/participant.go b/pkg/conference/participant.go index f7565fb..a7b9385 100644 --- a/pkg/conference/participant.go +++ b/pkg/conference/participant.go @@ -2,17 +2,15 @@ package conference import ( "encoding/json" - "errors" "github.com/matrix-org/waterfall/pkg/peer" "github.com/matrix-org/waterfall/pkg/signaling" "github.com/pion/webrtc/v3" + "github.com/sirupsen/logrus" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) -var ErrInvalidSFUMessage = errors.New("invalid SFU message") - // Things that we assume as identifiers for the participants in the call. // There could be no 2 participants in the room with identical IDs. type ParticipantID struct { @@ -23,6 +21,7 @@ type ParticipantID struct { // Participant represents a participant in the conference. type Participant struct { id ParticipantID + logger *logrus.Entry peer *peer.Peer[ParticipantID] remoteSessionID id.SessionID streamMetadata event.CallSDPStreamMetadata @@ -37,17 +36,14 @@ func (p *Participant) asMatrixRecipient() signaling.MatrixRecipient { } } -func (p *Participant) sendDataChannelMessage(toSend event.SFUMessage) error { +func (p *Participant) sendDataChannelMessage(toSend event.SFUMessage) { jsonToSend, err := json.Marshal(toSend) if err != nil { - return ErrInvalidSFUMessage + p.logger.Error("Failed to marshal data channel message") } if err := p.peer.SendOverDataChannel(string(jsonToSend)); err != nil { // FIXME: We must buffer the message in this case and re-send it once the data channel is recovered! - // Or use Matrix signaling to inform the peer about the problem. - return err + p.logger.Error("Failed to send data channel message") } - - return nil } diff --git a/pkg/conference/processor.go b/pkg/conference/processor.go index 89db48d..f4f38b5 100644 --- a/pkg/conference/processor.go +++ b/pkg/conference/processor.go @@ -11,12 +11,13 @@ import ( func (c *Conference) processMessages() { for { - // Read a message from the stream (of type peer.Message) and process it. + // Read a message from the participant in the room (our local counterpart of it) message := <-c.peerEvents c.processPeerMessage(message) } } +// Process a message from a local peer. func (c *Conference) processPeerMessage(message common.Message[ParticipantID, peer.MessageContent]) { participant := c.getParticipant(message.Sender, errors.New("received a message from a deleted participant")) if participant == nil { @@ -86,15 +87,10 @@ func (c *Conference) processPeerMessage(message common.Message[ParticipantID, pe c.handleDataChannelMessage(participant, sfuMessage) case peer.DataChannelAvailable: - toSend := event.SFUMessage{ + participant.sendDataChannelMessage(event.SFUMessage{ Op: event.SFUOperationMetadata, Metadata: c.getAvailableStreamsFor(participant.id), - } - - if err := participant.sendDataChannelMessage(toSend); err != nil { - c.logger.Errorf("Failed to send SFU message to open data channel: %v", err) - return - } + }) default: c.logger.Errorf("Unknown message type: %T", msg) @@ -108,42 +104,47 @@ func (c *Conference) handleDataChannelMessage(participant *Participant, sfuMessa // Get the tracks that correspond to the tracks that the participant wants to receive. for _, track := range c.getTracks(sfuMessage.Start) { if err := participant.peer.SubscribeTo(track); err != nil { - c.logger.Errorf("Failed to subscribe to track: %v", err) + participant.logger.Errorf("Failed to subscribe to track: %v", err) return } } case event.SFUOperationAnswer: if err := participant.peer.ProcessSDPAnswer(sfuMessage.SDP); err != nil { - c.logger.Errorf("Failed to set SDP answer: %v", err) + participant.logger.Errorf("Failed to set SDP answer: %v", err) return } - // TODO: Clarify the semantics of publish (just a new sdp offer?). case event.SFUOperationPublish: - // TODO: Clarify the semantics of publish (how is it different from unpublish?). + answer, err := participant.peer.ProcessSDPOffer(sfuMessage.SDP) + if err != nil { + participant.logger.Errorf("Failed to set SDP offer: %v", err) + return + } + + participant.sendDataChannelMessage(event.SFUMessage{ + Op: event.SFUOperationAnswer, + SDP: answer.SDP, + }) + case event.SFUOperationUnpublish: - // TODO: Handle the heartbeat message here (updating the last timestamp etc). + // TODO: Clarify the semantics of unpublish. case event.SFUOperationAlive: + // TODO: Handle the heartbeat message here (updating the last timestamp etc). case event.SFUOperationMetadata: participant.streamMetadata = sfuMessage.Metadata // Inform all participants about new metadata available. - for id, participant := range c.participants { + for _, otherParticipant := range c.participants { // Skip ourselves. - if id == participant.id { + if participant.id == otherParticipant.id { continue } - toSend := event.SFUMessage{ + otherParticipant.sendDataChannelMessage(event.SFUMessage{ Op: event.SFUOperationMetadata, - Metadata: c.getAvailableStreamsFor(id), - } - - if err := participant.sendDataChannelMessage(toSend); err != nil { - c.logger.Errorf("Failed to send SFU message: %v", err) - return - } + Metadata: c.getAvailableStreamsFor(otherParticipant.id), + }) } } } diff --git a/pkg/peer/peer.go b/pkg/peer/peer.go index 8177d5e..9376a04 100644 --- a/pkg/peer/peer.go +++ b/pkg/peer/peer.go @@ -157,8 +157,7 @@ func (p *Peer[ID]) ProcessSDPOffer(sdpOffer string) (*webrtc.SessionDescription, return nil, ErrCantSetLocalDescription } - // TODO: Do we really need to call `webrtc.GatheringCompletePromise` - // as in the previous version of the `waterfall` here? + // TODO: Do we really need to call `webrtc.GatheringCompletePromise` here? sdpAnswer := p.peerConnection.LocalDescription() if sdpAnswer == nil { diff --git a/pkg/peer/webrtc.go b/pkg/peer/webrtc.go index 46c54ef..461dd48 100644 --- a/pkg/peer/webrtc.go +++ b/pkg/peer/webrtc.go @@ -49,7 +49,6 @@ func (p *Peer[ID]) onRtpTrackReceived(remoteTrack *webrtc.TrackRemote, receiver for { index, _, readErr := remoteTrack.Read(rtpBuf) - // TODO: inform the conference that this publisher's track is not available anymore. if readErr != nil { if readErr == io.EOF { // finished, no more data, no error, inform others p.logger.Info("remote track closed") @@ -100,18 +99,15 @@ func (p *Peer[ID]) onNegotiationNeeded() { func (p *Peer[ID]) onICEConnectionStateChanged(state webrtc.ICEConnectionState) { p.logger.WithField("state", state).Debug("ICE connection state changed") + // TODO: Ask Simon if we should do it here as in the previous implementation of the + // `waterfall` or the way I did it in this new implementation. switch state { case webrtc.ICEConnectionStateFailed, webrtc.ICEConnectionStateDisconnected: // TODO: We may want to treat it as an opportunity for the ICE restart instead. - // TODO: Ask Simon if we should do it here as in the previous implementation of the - // `waterfall` or the way I did it in this new implementation. // p.notify <- PeerLeftTheCall{sender: p.data} case webrtc.ICEConnectionStateCompleted, webrtc.ICEConnectionStateConnected: // TODO: Start keep-alive timer over the data channel to check the connecitons that hanged. - // TODO: Ask Simon if we should do it here as in the previous implementation of the - // `waterfall` or the way I did it in this new implementation. // p.notify <- PeerJoinedTheCall{sender: p.data} - p.sink.Send(ICEGatheringComplete{}) } } From ce324819a1daf7cea15566ec5062e340e01f51a0 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Thu, 24 Nov 2022 22:40:29 +0100 Subject: [PATCH 10/62] paer: unsubscribe from obsolete tracks Once the track become obsolete, we unsubscribe from them. --- go.mod | 3 ++- go.sum | 5 ++++- pkg/conference/conference.go | 25 ++++++++++++++++++++++++ pkg/conference/processor.go | 38 +++++++++++++++--------------------- pkg/peer/peer.go | 18 ++++++++++++++++- pkg/peer/webrtc.go | 5 ++--- pkg/signaling/matrix.go | 1 - 7 files changed, 66 insertions(+), 29 deletions(-) diff --git a/go.mod b/go.mod index 67dcfa9..adaf0b3 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require github.com/pion/webrtc/v3 v3.1.31 require ( github.com/pion/rtcp v1.2.9 github.com/sirupsen/logrus v1.9.0 + golang.org/x/exp v0.0.0-20221114191408-850992195362 gopkg.in/yaml.v3 v3.0.1 maunium.net/go/mautrix v0.11.0 ) @@ -34,7 +35,7 @@ require ( github.com/tidwall/sjson v1.2.4 // indirect golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d // indirect golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e // indirect - golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 // indirect + golang.org/x/sys v0.1.0 // indirect golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect ) diff --git a/go.sum b/go.sum index 17aeab1..e02475c 100644 --- a/go.sum +++ b/go.sum @@ -103,6 +103,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20220131195533-30dcbda58838/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d h1:sK3txAijHtOK88l68nt020reeT1ZdKLIYetKl95FzVY= golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/exp v0.0.0-20221114191408-850992195362 h1:NoHlPRbyl1VFI6FjwHtPQCN7wAMXI6cKcqrmXhOOfBQ= +golang.org/x/exp v0.0.0-20221114191408-850992195362/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= @@ -134,8 +136,9 @@ golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 h1:0A+M6Uqn+Eje4kHMK80dtF3JCXC4ykBgQG4Fe06QRhQ= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.1.0 h1:kunALQeHf1/185U1i0GOB/fy1IPRDDpuoOOqRReG57U= +golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/pkg/conference/conference.go b/pkg/conference/conference.go index 880f5f6..0e8cd36 100644 --- a/pkg/conference/conference.go +++ b/pkg/conference/conference.go @@ -22,6 +22,7 @@ import ( "github.com/matrix-org/waterfall/pkg/signaling" "github.com/pion/webrtc/v3" "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" "maunium.net/go/mautrix/event" ) @@ -175,8 +176,20 @@ func (c *Conference) removeParticipant(participantID ParticipantID) { return } + // Terminate the participant and remove it from the list. participant.peer.Terminate() delete(c.participants, participantID) + + // Inform the other participants about updated metadata (since the participant left + // the corresponding streams of the participant are no longer available, so we're informing + // others about it). + c.resendMetadataToAllExcept(participantID) + + // Remove the participant's tracks from all participants who might have subscribed to them. + obsoleteTracks := maps.Values(participant.publishedTracks) + for _, otherParticipant := range c.participants { + otherParticipant.peer.UnsubscribeFrom(obsoleteTracks) + } } // Helper to get the list of available streams for a given participant, i.e. the list of streams @@ -207,3 +220,15 @@ func (c *Conference) getTracks(identifiers []event.SFUTrackDescription) []*webrt } return tracks } + +// Helper that sends current metadata about all available tracks to all participants except a given one. +func (c *Conference) resendMetadataToAllExcept(exceptMe ParticipantID) { + for participantID, participant := range c.participants { + if participantID != exceptMe { + participant.sendDataChannelMessage(event.SFUMessage{ + Op: event.SFUOperationMetadata, + Metadata: c.getAvailableStreamsFor(participantID), + }) + } + } +} diff --git a/pkg/conference/processor.go b/pkg/conference/processor.go index f4f38b5..b838f74 100644 --- a/pkg/conference/processor.go +++ b/pkg/conference/processor.go @@ -6,6 +6,7 @@ import ( "github.com/matrix-org/waterfall/pkg/common" "github.com/matrix-org/waterfall/pkg/peer" + "github.com/pion/webrtc/v3" "maunium.net/go/mautrix/event" ) @@ -28,10 +29,11 @@ func (c *Conference) processPeerMessage(message common.Message[ParticipantID, pe // determine the actual type of the message. switch msg := message.Content.(type) { case peer.JoinedTheCall: + c.resendMetadataToAllExcept(participant.id) + case peer.LeftTheCall: - delete(c.participants, message.Sender) - // TODO: Send new metadata about available streams to all participants. - // TODO: Send the hangup event over the Matrix back to the user. + c.removeParticipant(message.Sender) + c.signaling.SendHangup(participant.asMatrixRecipient(), event.CallHangupUnknownError) case peer.NewTrackPublished: key := event.SFUTrackDescription{ @@ -52,7 +54,13 @@ func (c *Conference) processPeerMessage(message common.Message[ParticipantID, pe TrackID: msg.Track.ID(), }) - // TODO: Should we remove the local tracks from every subscriber as well? Or will it happen automatically? + for _, otherParticipant := range c.participants { + if otherParticipant.id == participant.id { + continue + } + + otherParticipant.peer.UnsubscribeFrom([]*webrtc.TrackLocalStaticRTP{msg.Track}) + } case peer.NewICECandidate: // Convert WebRTC ICE candidate to Matrix ICE candidate. @@ -69,13 +77,11 @@ func (c *Conference) processPeerMessage(message common.Message[ParticipantID, pe c.signaling.SendCandidatesGatheringFinished(participant.asMatrixRecipient()) case peer.RenegotiationRequired: - toSend := event.SFUMessage{ + participant.sendDataChannelMessage(event.SFUMessage{ Op: event.SFUOperationOffer, SDP: msg.Offer.SDP, Metadata: c.getAvailableStreamsFor(participant.id), - } - - participant.sendDataChannelMessage(toSend) + }) case peer.DataChannelMessage: var sfuMessage event.SFUMessage @@ -130,21 +136,9 @@ func (c *Conference) handleDataChannelMessage(participant *Participant, sfuMessa case event.SFUOperationUnpublish: // TODO: Clarify the semantics of unpublish. case event.SFUOperationAlive: - // TODO: Handle the heartbeat message here (updating the last timestamp etc). + // FIXME: Handle the heartbeat message here (updating the last timestamp etc). case event.SFUOperationMetadata: participant.streamMetadata = sfuMessage.Metadata - - // Inform all participants about new metadata available. - for _, otherParticipant := range c.participants { - // Skip ourselves. - if participant.id == otherParticipant.id { - continue - } - - otherParticipant.sendDataChannelMessage(event.SFUMessage{ - Op: event.SFUOperationMetadata, - Metadata: c.getAvailableStreamsFor(otherParticipant.id), - }) - } + c.resendMetadataToAllExcept(participant.id) } } diff --git a/pkg/peer/peer.go b/pkg/peer/peer.go index 9376a04..eaafdfa 100644 --- a/pkg/peer/peer.go +++ b/pkg/peer/peer.go @@ -79,7 +79,7 @@ func (p *Peer[ID]) Terminate() { p.sink.Seal() } -// Add given track to our peer connection, so that it can be sent to the remote peer. +// Adds given track to our peer connection, so that it can be sent to the remote peer. func (p *Peer[ID]) SubscribeTo(track *webrtc.TrackLocalStaticRTP) error { _, err := p.peerConnection.AddTrack(track) if err != nil { @@ -90,6 +90,22 @@ func (p *Peer[ID]) SubscribeTo(track *webrtc.TrackLocalStaticRTP) error { return nil } +// Unsubscribes from the given list of tracks. +func (p *Peer[ID]) UnsubscribeFrom(tracks []*webrtc.TrackLocalStaticRTP) { + // That's unfortunately an O(m*n) operation, but we don't expect the number of tracks to be big. + for _, presentTrack := range p.peerConnection.GetSenders() { + for _, trackToUnsubscribe := range tracks { + presentTrackID, presentStreamID := presentTrack.Track().ID(), presentTrack.Track().StreamID() + trackID, streamID := trackToUnsubscribe.ID(), trackToUnsubscribe.StreamID() + if presentTrackID == trackID && presentStreamID == streamID { + if err := p.peerConnection.RemoveTrack(presentTrack); err != nil { + p.logger.WithError(err).Error("failed to remove track") + } + } + } + } +} + // Tries to send the given message to the remote counterpart of our peer. func (p *Peer[ID]) SendOverDataChannel(json string) error { p.dataChannelMutex.Lock() diff --git a/pkg/peer/webrtc.go b/pkg/peer/webrtc.go index 461dd48..741f135 100644 --- a/pkg/peer/webrtc.go +++ b/pkg/peer/webrtc.go @@ -99,14 +99,13 @@ func (p *Peer[ID]) onNegotiationNeeded() { func (p *Peer[ID]) onICEConnectionStateChanged(state webrtc.ICEConnectionState) { p.logger.WithField("state", state).Debug("ICE connection state changed") - // TODO: Ask Simon if we should do it here as in the previous implementation of the - // `waterfall` or the way I did it in this new implementation. + // TODO: Ask Simon if we should do it here as in the previous implementation. switch state { case webrtc.ICEConnectionStateFailed, webrtc.ICEConnectionStateDisconnected: // TODO: We may want to treat it as an opportunity for the ICE restart instead. // p.notify <- PeerLeftTheCall{sender: p.data} case webrtc.ICEConnectionStateCompleted, webrtc.ICEConnectionStateConnected: - // TODO: Start keep-alive timer over the data channel to check the connecitons that hanged. + // FIXME: Start keep-alive timer over the data channel to check the connecitons that hanged. // p.notify <- PeerJoinedTheCall{sender: p.data} } } diff --git a/pkg/signaling/matrix.go b/pkg/signaling/matrix.go index 65c39d9..0c84331 100644 --- a/pkg/signaling/matrix.go +++ b/pkg/signaling/matrix.go @@ -123,7 +123,6 @@ func (m *MatrixForConference) createBaseEventContent( // Sends a to-device event to the given user. func (m *MatrixForConference) sendToDevice(user MatrixRecipient, eventType event.Type, eventContent *event.Content) { - // TODO: Don't create logger again and again, it might be a bit expensive. logger := logrus.WithFields(logrus.Fields{ "user_id": user.UserID, "device_id": user.DeviceID, From b47600cda85575dcedd0178615632dfaf3f5305c Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Thu, 24 Nov 2022 22:44:47 +0100 Subject: [PATCH 11/62] peer: add naive handling of RTCPs This is a temporary solution. --- pkg/peer/peer.go | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/pkg/peer/peer.go b/pkg/peer/peer.go index eaafdfa..a6ef022 100644 --- a/pkg/peer/peer.go +++ b/pkg/peer/peer.go @@ -81,12 +81,24 @@ func (p *Peer[ID]) Terminate() { // Adds given track to our peer connection, so that it can be sent to the remote peer. func (p *Peer[ID]) SubscribeTo(track *webrtc.TrackLocalStaticRTP) error { - _, err := p.peerConnection.AddTrack(track) + rtpSender, err := p.peerConnection.AddTrack(track) if err != nil { p.logger.WithError(err).Error("failed to add track") return ErrCantSubscribeToTrack } + // Read incoming RTCP packets + // Before these packets are returned they are processed by interceptors. For things + // like NACK this needs to be called. + go func() { + rtcpBuf := make([]byte, 1500) + for { + if _, _, rtcpErr := rtpSender.Read(rtcpBuf); rtcpErr != nil { + return + } + } + }() + return nil } From a8312b0b4392d1d268bf49ac296e749219248de1 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Fri, 25 Nov 2022 18:09:15 +0100 Subject: [PATCH 12/62] router: remove useles checks for types It turns out that currently `mautrix` just simply returns a pointer to an empty structure in this case (which is probably not the behavior that we want to have!), but because of this there is no point to check the types. --- pkg/router.go | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/pkg/router.go b/pkg/router.go index 11bf567..842a78a 100644 --- a/pkg/router.go +++ b/pkg/router.go @@ -43,8 +43,6 @@ func newRouter(matrix *signaling.MatrixClient, config conf.Config) *Router { } // Handles incoming To-Device events that the SFU receives from clients. -// -//nolint:funlen func (r *Router) handleMatrixEvent(evt *event.Event) { // TODO: Don't create logger again and again, it might be a bit expensive. logger := logrus.WithFields(logrus.Fields{ @@ -57,10 +55,6 @@ func (r *Router) handleMatrixEvent(evt *event.Event) { // Someone tries to participate in a call (join a call). case event.ToDeviceCallInvite.Type: invite := evt.Content.AsCallInvite() - if invite == nil { - logger.Error("failed to parse invite") - return - } // If there is an invitation sent and the conference does not exist, create one. if conference := r.conferences[invite.ConfID]; conference == nil { @@ -83,10 +77,6 @@ func (r *Router) handleMatrixEvent(evt *event.Event) { // Someone tries to send ICE candidates to the existing call. case event.ToDeviceCallCandidates.Type: candidates := evt.Content.AsCallCandidates() - if candidates == nil { - logger.Error("failed to parse candidates") - return - } conference := r.conferences[candidates.ConfID] if conference == nil { @@ -104,10 +94,6 @@ func (r *Router) handleMatrixEvent(evt *event.Event) { // Someone informs us about them accepting our (SFU's) SDP answer for an existing call. case event.ToDeviceCallSelectAnswer.Type: selectAnswer := evt.Content.AsCallSelectAnswer() - if selectAnswer == nil { - logger.Error("failed to parse select_answer") - return - } conference := r.conferences[selectAnswer.ConfID] if conference == nil { @@ -125,10 +111,6 @@ func (r *Router) handleMatrixEvent(evt *event.Event) { // Someone tries to inform us about leaving an existing call. case event.ToDeviceCallHangup.Type: hangup := evt.Content.AsCallHangup() - if hangup == nil { - logger.Error("failed to parse hangup") - return - } conference := r.conferences[hangup.ConfID] if conference == nil { From 0b75dfc68c3cdc80fc8d6c07ee4618155a922306 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Fri, 25 Nov 2022 18:13:38 +0100 Subject: [PATCH 13/62] conference: rename peer message sink for clarity --- pkg/conference/conference.go | 6 +++--- pkg/conference/processor.go | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pkg/conference/conference.go b/pkg/conference/conference.go index 0e8cd36..7e9195c 100644 --- a/pkg/conference/conference.go +++ b/pkg/conference/conference.go @@ -32,7 +32,7 @@ type Conference struct { config Config signaling signaling.MatrixSignaling participants map[ParticipantID]*Participant - peerEvents chan common.Message[ParticipantID, peer.MessageContent] + peerMessages chan common.Message[ParticipantID, peer.MessageContent] logger *logrus.Entry } @@ -42,7 +42,7 @@ func NewConference(confID string, config Config, signaling signaling.MatrixSigna config: config, signaling: signaling, participants: make(map[ParticipantID]*Participant), - peerEvents: make(chan common.Message[ParticipantID, peer.MessageContent]), + peerMessages: make(chan common.Message[ParticipantID, peer.MessageContent]), logger: logrus.WithFields(logrus.Fields{"conf_id": confID}), } @@ -74,7 +74,7 @@ func (c *Conference) OnNewParticipant(participantID ParticipantID, inviteEvent * // In this case we treat this new invitation as a new SDP offer. Otherwise, we create a new one. sdpAnswer, err := func() (*webrtc.SessionDescription, error) { if participant == nil { - messageSink := common.NewMessageSink(participantID, c.peerEvents) + messageSink := common.NewMessageSink(participantID, c.peerMessages) peer, answer, err := peer.NewPeer(inviteEvent.Offer.SDP, messageSink, logger) if err != nil { diff --git a/pkg/conference/processor.go b/pkg/conference/processor.go index b838f74..ab47d5a 100644 --- a/pkg/conference/processor.go +++ b/pkg/conference/processor.go @@ -13,7 +13,7 @@ import ( func (c *Conference) processMessages() { for { // Read a message from the participant in the room (our local counterpart of it) - message := <-c.peerEvents + message := <-c.peerMessages c.processPeerMessage(message) } } From e0f4cb760f8235381da9f993cb3da8d924101526 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Fri, 25 Nov 2022 23:25:56 +0100 Subject: [PATCH 14/62] conference: ensure a proper conference lifetime --- pkg/conference/conference.go | 234 ----------------------------------- pkg/conference/matrix.go | 116 +++++++++++++++++ pkg/conference/processor.go | 33 ++++- pkg/conference/start.go | 57 +++++++++ pkg/conference/state.go | 108 ++++++++++++++++ pkg/router.go | 125 +++++++------------ 6 files changed, 359 insertions(+), 314 deletions(-) delete mode 100644 pkg/conference/conference.go create mode 100644 pkg/conference/matrix.go create mode 100644 pkg/conference/start.go create mode 100644 pkg/conference/state.go diff --git a/pkg/conference/conference.go b/pkg/conference/conference.go deleted file mode 100644 index 7e9195c..0000000 --- a/pkg/conference/conference.go +++ /dev/null @@ -1,234 +0,0 @@ -/* -Copyright 2022 The Matrix.org Foundation C.I.C. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package conference - -import ( - "github.com/matrix-org/waterfall/pkg/common" - "github.com/matrix-org/waterfall/pkg/peer" - "github.com/matrix-org/waterfall/pkg/signaling" - "github.com/pion/webrtc/v3" - "github.com/sirupsen/logrus" - "golang.org/x/exp/maps" - "maunium.net/go/mautrix/event" -) - -// A single conference. Call and conference mean the same in context of Matrix. -type Conference struct { - id string - config Config - signaling signaling.MatrixSignaling - participants map[ParticipantID]*Participant - peerMessages chan common.Message[ParticipantID, peer.MessageContent] - logger *logrus.Entry -} - -func NewConference(confID string, config Config, signaling signaling.MatrixSignaling) *Conference { - conference := &Conference{ - id: confID, - config: config, - signaling: signaling, - participants: make(map[ParticipantID]*Participant), - peerMessages: make(chan common.Message[ParticipantID, peer.MessageContent]), - logger: logrus.WithFields(logrus.Fields{"conf_id": confID}), - } - - // Start conference "main loop". - go conference.processMessages() - return conference -} - -// New participant tries to join the conference. -func (c *Conference) OnNewParticipant(participantID ParticipantID, inviteEvent *event.CallInviteEventContent) { - logger := c.logger.WithFields(logrus.Fields{ - "user_id": participantID.UserID, - "device_id": participantID.DeviceID, - }) - - // As per MSC3401, when the `session_id` field changes from an incoming `m.call.member` event, - // any existing calls from this device in this call should be terminated. - participant := c.getParticipant(participantID, nil) - if participant != nil { - if participant.remoteSessionID == inviteEvent.SenderSessionID { - c.logger.Errorf("Found existing participant with equal DeviceID and SessionID") - } else { - c.removeParticipant(participantID) - participant = nil - } - } - - // If participant exists still exists, then it means that the client does not behave properly. - // In this case we treat this new invitation as a new SDP offer. Otherwise, we create a new one. - sdpAnswer, err := func() (*webrtc.SessionDescription, error) { - if participant == nil { - messageSink := common.NewMessageSink(participantID, c.peerMessages) - - peer, answer, err := peer.NewPeer(inviteEvent.Offer.SDP, messageSink, logger) - if err != nil { - return nil, err - } - - participant = &Participant{ - id: participantID, - peer: peer, - logger: logger, - remoteSessionID: inviteEvent.SenderSessionID, - streamMetadata: inviteEvent.SDPStreamMetadata, - publishedTracks: make(map[event.SFUTrackDescription]*webrtc.TrackLocalStaticRTP), - } - - c.participants[participantID] = participant - return answer, nil - } else { - answer, err := participant.peer.ProcessSDPOffer(inviteEvent.Offer.SDP) - if err != nil { - return nil, err - } - return answer, nil - } - }() - if err != nil { - logger.WithError(err).Errorf("Failed to process SDP offer") - return - } - - // Send the answer back to the remote peer. - recipient := participant.asMatrixRecipient() - streamMetadata := c.getAvailableStreamsFor(participantID) - c.signaling.SendSDPAnswer(recipient, streamMetadata, sdpAnswer.SDP) -} - -// Process new ICE candidates received from Matrix signaling (from the remote peer) and forward them to -// our internal peer connection. -func (c *Conference) OnCandidates(participantID ParticipantID, ev *event.CallCandidatesEventContent) { - if participant := c.getParticipant(participantID, nil); participant != nil { - // Convert the candidates to the WebRTC format. - candidates := make([]webrtc.ICECandidateInit, len(ev.Candidates)) - for i, candidate := range ev.Candidates { - SDPMLineIndex := uint16(candidate.SDPMLineIndex) - candidates[i] = webrtc.ICECandidateInit{ - Candidate: candidate.Candidate, - SDPMid: &candidate.SDPMID, - SDPMLineIndex: &SDPMLineIndex, - } - } - - participant.peer.ProcessNewRemoteCandidates(candidates) - } -} - -// Process an acknowledgement from the remote peer that the SDP answer has been received -// and that the call can now proceed. -func (c *Conference) OnSelectAnswer(participantID ParticipantID, ev *event.CallSelectAnswerEventContent) { - if participant := c.getParticipant(participantID, nil); participant != nil { - if ev.SelectedPartyID != participantID.DeviceID.String() { - c.logger.WithFields(logrus.Fields{ - "device_id": ev.SelectedPartyID, - "user_id": participantID, - }).Errorf("Call was answered on a different device, kicking this peer") - participant.peer.Terminate() - } - } -} - -// Process a message from the remote peer telling that it wants to hang up the call. -func (c *Conference) OnHangup(participantID ParticipantID, ev *event.CallHangupEventContent) { - c.removeParticipant(participantID) -} - -func (c *Conference) getParticipant(participantID ParticipantID, optionalErrorMessage error) *Participant { - participant, ok := c.participants[participantID] - if !ok { - logEntry := c.logger.WithFields(logrus.Fields{ - "user_id": participantID.UserID, - "device_id": participantID.DeviceID, - }) - - if optionalErrorMessage != nil { - logEntry.WithError(optionalErrorMessage) - } else { - logEntry.Error("Participant not found") - } - - return nil - } - - return participant -} - -// Helper to terminate and remove a participant from the conference. -func (c *Conference) removeParticipant(participantID ParticipantID) { - participant := c.getParticipant(participantID, nil) - if participant == nil { - return - } - - // Terminate the participant and remove it from the list. - participant.peer.Terminate() - delete(c.participants, participantID) - - // Inform the other participants about updated metadata (since the participant left - // the corresponding streams of the participant are no longer available, so we're informing - // others about it). - c.resendMetadataToAllExcept(participantID) - - // Remove the participant's tracks from all participants who might have subscribed to them. - obsoleteTracks := maps.Values(participant.publishedTracks) - for _, otherParticipant := range c.participants { - otherParticipant.peer.UnsubscribeFrom(obsoleteTracks) - } -} - -// Helper to get the list of available streams for a given participant, i.e. the list of streams -// that a given participant **can subscribe to**. -func (c *Conference) getAvailableStreamsFor(forParticipant ParticipantID) event.CallSDPStreamMetadata { - streamsMetadata := make(event.CallSDPStreamMetadata) - for id, participant := range c.participants { - if forParticipant != id { - for streamID, metadata := range participant.streamMetadata { - streamsMetadata[streamID] = metadata - } - } - } - - return streamsMetadata -} - -// Helper that returns the list of streams inside this conference that match the given stream IDs. -func (c *Conference) getTracks(identifiers []event.SFUTrackDescription) []*webrtc.TrackLocalStaticRTP { - tracks := make([]*webrtc.TrackLocalStaticRTP, len(identifiers)) - for _, participant := range c.participants { - // Check if this participant has any of the tracks that we're looking for. - for _, identifier := range identifiers { - if track, ok := participant.publishedTracks[identifier]; ok { - tracks = append(tracks, track) - } - } - } - return tracks -} - -// Helper that sends current metadata about all available tracks to all participants except a given one. -func (c *Conference) resendMetadataToAllExcept(exceptMe ParticipantID) { - for participantID, participant := range c.participants { - if participantID != exceptMe { - participant.sendDataChannelMessage(event.SFUMessage{ - Op: event.SFUOperationMetadata, - Metadata: c.getAvailableStreamsFor(participantID), - }) - } - } -} diff --git a/pkg/conference/matrix.go b/pkg/conference/matrix.go new file mode 100644 index 0000000..7998626 --- /dev/null +++ b/pkg/conference/matrix.go @@ -0,0 +1,116 @@ +package conference + +import ( + "github.com/matrix-org/waterfall/pkg/common" + "github.com/matrix-org/waterfall/pkg/peer" + "github.com/pion/webrtc/v3" + "github.com/sirupsen/logrus" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +type MessageContent interface{} + +type IncomingMatrixMessage struct { + UserID id.UserID + Content MessageContent +} + +// New participant tries to join the conference. +func (c *Conference) onNewParticipant(participantID ParticipantID, inviteEvent *event.CallInviteEventContent) error { + logger := c.logger.WithFields(logrus.Fields{ + "user_id": participantID.UserID, + "device_id": participantID.DeviceID, + }) + + // As per MSC3401, when the `session_id` field changes from an incoming `m.call.member` event, + // any existing calls from this device in this call should be terminated. + participant := c.getParticipant(participantID, nil) + if participant != nil { + if participant.remoteSessionID == inviteEvent.SenderSessionID { + c.logger.Errorf("Found existing participant with equal DeviceID and SessionID") + } else { + c.removeParticipant(participantID) + participant = nil + } + } + + // If participant exists still exists, then it means that the client does not behave properly. + // In this case we treat this new invitation as a new SDP offer. Otherwise, we create a new one. + sdpAnswer, err := func() (*webrtc.SessionDescription, error) { + if participant == nil { + messageSink := common.NewMessageSink(participantID, c.peerMessages) + + peer, answer, err := peer.NewPeer(inviteEvent.Offer.SDP, messageSink, logger) + if err != nil { + return nil, err + } + + participant = &Participant{ + id: participantID, + peer: peer, + logger: logger, + remoteSessionID: inviteEvent.SenderSessionID, + streamMetadata: inviteEvent.SDPStreamMetadata, + publishedTracks: make(map[event.SFUTrackDescription]*webrtc.TrackLocalStaticRTP), + } + + c.participants[participantID] = participant + return answer, nil + } else { + answer, err := participant.peer.ProcessSDPOffer(inviteEvent.Offer.SDP) + if err != nil { + return nil, err + } + return answer, nil + } + }() + if err != nil { + logger.WithError(err).Errorf("Failed to process SDP offer") + return err + } + + // Send the answer back to the remote peer. + recipient := participant.asMatrixRecipient() + streamMetadata := c.getAvailableStreamsFor(participantID) + c.signaling.SendSDPAnswer(recipient, streamMetadata, sdpAnswer.SDP) + return nil +} + +// Process new ICE candidates received from Matrix signaling (from the remote peer) and forward them to +// our internal peer connection. +func (c *Conference) onCandidates(participantID ParticipantID, ev *event.CallCandidatesEventContent) { + if participant := c.getParticipant(participantID, nil); participant != nil { + // Convert the candidates to the WebRTC format. + candidates := make([]webrtc.ICECandidateInit, len(ev.Candidates)) + for i, candidate := range ev.Candidates { + SDPMLineIndex := uint16(candidate.SDPMLineIndex) + candidates[i] = webrtc.ICECandidateInit{ + Candidate: candidate.Candidate, + SDPMid: &candidate.SDPMID, + SDPMLineIndex: &SDPMLineIndex, + } + } + + participant.peer.ProcessNewRemoteCandidates(candidates) + } +} + +// Process an acknowledgement from the remote peer that the SDP answer has been received +// and that the call can now proceed. +func (c *Conference) onSelectAnswer(participantID ParticipantID, ev *event.CallSelectAnswerEventContent) { + if participant := c.getParticipant(participantID, nil); participant != nil { + if ev.SelectedPartyID != participantID.DeviceID.String() { + c.logger.WithFields(logrus.Fields{ + "device_id": ev.SelectedPartyID, + "user_id": participantID, + }).Errorf("Call was answered on a different device, kicking this peer") + participant.peer.Terminate() + } + } +} + +// Process a message from the remote peer telling that it wants to hang up the call. +func (c *Conference) onHangup(participantID ParticipantID, ev *event.CallHangupEventContent) { + c.removeParticipant(participantID) +} diff --git a/pkg/conference/processor.go b/pkg/conference/processor.go index ab47d5a..acddc94 100644 --- a/pkg/conference/processor.go +++ b/pkg/conference/processor.go @@ -10,11 +10,23 @@ import ( "maunium.net/go/mautrix/event" ) +// Listen on messages from incoming channels and process them. +// This is essentially the main loop of the conference. +// If this function returns, the conference is over. func (c *Conference) processMessages() { for { - // Read a message from the participant in the room (our local counterpart of it) - message := <-c.peerMessages - c.processPeerMessage(message) + select { + case msg := <-c.peerMessages: + c.processPeerMessage(msg) + case msg := <-c.matrixBus: + c.processMatrixMessage(msg) + } + + // If there are no more participants, stop the conference. + if len(c.participants) == 0 { + c.logger.Info("No more participants, stopping the conference") + return + } } } @@ -142,3 +154,18 @@ func (c *Conference) handleDataChannelMessage(participant *Participant, sfuMessa c.resendMetadataToAllExcept(participant.id) } } + +func (c *Conference) processMatrixMessage(msg IncomingMatrixMessage) { + switch ev := msg.Content.(type) { + case *event.CallInviteEventContent: + c.onNewParticipant(ParticipantID{UserID: msg.UserID, DeviceID: ev.DeviceID}, ev) + case *event.CallCandidatesEventContent: + c.onCandidates(ParticipantID{UserID: msg.UserID, DeviceID: ev.DeviceID}, ev) + case *event.CallSelectAnswerEventContent: + c.onSelectAnswer(ParticipantID{UserID: msg.UserID, DeviceID: ev.DeviceID}, ev) + case *event.CallHangupEventContent: + c.onHangup(ParticipantID{UserID: msg.UserID, DeviceID: ev.DeviceID}, ev) + default: + c.logger.Errorf("Unexpected event type: %T", ev) + } +} diff --git a/pkg/conference/start.go b/pkg/conference/start.go new file mode 100644 index 0000000..3832ab7 --- /dev/null +++ b/pkg/conference/start.go @@ -0,0 +1,57 @@ +/* +Copyright 2022 The Matrix.org Foundation C.I.C. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package conference + +import ( + "github.com/matrix-org/waterfall/pkg/common" + "github.com/matrix-org/waterfall/pkg/peer" + "github.com/matrix-org/waterfall/pkg/signaling" + "github.com/sirupsen/logrus" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +// Starts a new conference or fails and returns an error. +func StartConference( + confID string, + config Config, + signaling signaling.MatrixSignaling, + UserID id.UserID, + inviteEvent *event.CallInviteEventContent, +) (chan<- IncomingMatrixMessage, error) { + matrixBus := make(chan IncomingMatrixMessage) + + conference := &Conference{ + id: confID, + config: config, + signaling: signaling, + participants: make(map[ParticipantID]*Participant), + peerMessages: make(chan common.Message[ParticipantID, peer.MessageContent]), + matrixBus: matrixBus, + logger: logrus.WithFields(logrus.Fields{"conf_id": confID}), + } + + participantID := ParticipantID{UserID: UserID, DeviceID: inviteEvent.DeviceID} + if err := conference.onNewParticipant(participantID, inviteEvent); err != nil { + return nil, err + } + + // Start conference "main loop". + go conference.processMessages() + + return matrixBus, nil +} diff --git a/pkg/conference/state.go b/pkg/conference/state.go new file mode 100644 index 0000000..a185e12 --- /dev/null +++ b/pkg/conference/state.go @@ -0,0 +1,108 @@ +package conference + +import ( + "github.com/matrix-org/waterfall/pkg/common" + "github.com/matrix-org/waterfall/pkg/peer" + "github.com/matrix-org/waterfall/pkg/signaling" + "github.com/pion/webrtc/v3" + "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" + "maunium.net/go/mautrix/event" +) + +// A single conference. Call and conference mean the same in context of Matrix. +type Conference struct { + id string + config Config + logger *logrus.Entry + + signaling signaling.MatrixSignaling + participants map[ParticipantID]*Participant + + peerMessages chan common.Message[ParticipantID, peer.MessageContent] + matrixBus <-chan IncomingMatrixMessage +} + +func (c *Conference) getParticipant(participantID ParticipantID, optionalErrorMessage error) *Participant { + participant, ok := c.participants[participantID] + if !ok { + logEntry := c.logger.WithFields(logrus.Fields{ + "user_id": participantID.UserID, + "device_id": participantID.DeviceID, + }) + + if optionalErrorMessage != nil { + logEntry.WithError(optionalErrorMessage) + } else { + logEntry.Error("Participant not found") + } + + return nil + } + + return participant +} + +// Helper to terminate and remove a participant from the conference. +func (c *Conference) removeParticipant(participantID ParticipantID) { + participant := c.getParticipant(participantID, nil) + if participant == nil { + return + } + + // Terminate the participant and remove it from the list. + participant.peer.Terminate() + delete(c.participants, participantID) + + // Inform the other participants about updated metadata (since the participant left + // the corresponding streams of the participant are no longer available, so we're informing + // others about it). + c.resendMetadataToAllExcept(participantID) + + // Remove the participant's tracks from all participants who might have subscribed to them. + obsoleteTracks := maps.Values(participant.publishedTracks) + for _, otherParticipant := range c.participants { + otherParticipant.peer.UnsubscribeFrom(obsoleteTracks) + } +} + +// Helper to get the list of available streams for a given participant, i.e. the list of streams +// that a given participant **can subscribe to**. +func (c *Conference) getAvailableStreamsFor(forParticipant ParticipantID) event.CallSDPStreamMetadata { + streamsMetadata := make(event.CallSDPStreamMetadata) + for id, participant := range c.participants { + if forParticipant != id { + for streamID, metadata := range participant.streamMetadata { + streamsMetadata[streamID] = metadata + } + } + } + + return streamsMetadata +} + +// Helper that returns the list of streams inside this conference that match the given stream IDs. +func (c *Conference) getTracks(identifiers []event.SFUTrackDescription) []*webrtc.TrackLocalStaticRTP { + tracks := make([]*webrtc.TrackLocalStaticRTP, len(identifiers)) + for _, participant := range c.participants { + // Check if this participant has any of the tracks that we're looking for. + for _, identifier := range identifiers { + if track, ok := participant.publishedTracks[identifier]; ok { + tracks = append(tracks, track) + } + } + } + return tracks +} + +// Helper that sends current metadata about all available tracks to all participants except a given one. +func (c *Conference) resendMetadataToAllExcept(exceptMe ParticipantID) { + for participantID, participant := range c.participants { + if participantID != exceptMe { + participant.sendDataChannelMessage(event.SFUMessage{ + Op: event.SFUOperationMetadata, + Metadata: c.getAvailableStreamsFor(participantID), + }) + } + } +} diff --git a/pkg/router.go b/pkg/router.go index 842a78a..8e168d5 100644 --- a/pkg/router.go +++ b/pkg/router.go @@ -17,6 +17,7 @@ limitations under the License. package main import ( + "github.com/matrix-org/waterfall/pkg/conference" conf "github.com/matrix-org/waterfall/pkg/conference" "github.com/matrix-org/waterfall/pkg/signaling" "github.com/sirupsen/logrus" @@ -27,8 +28,8 @@ import ( type Router struct { // Matrix matrix. matrix *signaling.MatrixClient - // All calls currently forwarded by this SFU. - conferences map[string]*conf.Conference + // Sinks of all conferences (all calls that are currently forwarded by this SFU). + conferenceSinks map[string]chan<- conf.IncomingMatrixMessage // Configuration for the calls. config conf.Config } @@ -36,103 +37,73 @@ type Router struct { // Creates a new instance of the SFU with the given configuration. func newRouter(matrix *signaling.MatrixClient, config conf.Config) *Router { return &Router{ - matrix: matrix, - conferences: make(map[string]*conf.Conference), - config: config, + matrix: matrix, + conferenceSinks: make(map[string]chan<- conference.IncomingMatrixMessage), + config: config, } } // Handles incoming To-Device events that the SFU receives from clients. func (r *Router) handleMatrixEvent(evt *event.Event) { - // TODO: Don't create logger again and again, it might be a bit expensive. + // Check if `conf_id` is present in the message (right messages do have it). + rawConferenceID, ok := evt.Content.Raw["conf_id"] + if !ok { + return + } + + // Try to parse the conference ID without parsing the whole event. + conferenceID, ok := rawConferenceID.(string) + if !ok { + return + } + logger := logrus.WithFields(logrus.Fields{ "type": evt.Type.Type, "user_id": evt.Sender.String(), - "conf_id": evt.Content.Raw["conf_id"], + "conf_id": conferenceID, }) + conference := r.conferenceSinks[conferenceID] + + // Only ToDeviceCallInvite events are allowed to create a new conference, others + // are expected to operate on an existing conference that is running on the SFU. + if conference == nil && evt.Type.Type != event.ToDeviceCallInvite.Type { + logger.Warnf("ignoring %s for an unknown conference %s, ignoring", &event.ToDeviceCallInvite.Type) + return + } + switch evt.Type.Type { // Someone tries to participate in a call (join a call). case event.ToDeviceCallInvite.Type: - invite := evt.Content.AsCallInvite() - // If there is an invitation sent and the conference does not exist, create one. - if conference := r.conferences[invite.ConfID]; conference == nil { - logger.Infof("creating new conference %s", invite.ConfID) - r.conferences[invite.ConfID] = conf.NewConference( - invite.ConfID, + if conference == nil { + logger.Infof("creating new conference %s", conferenceID) + conferenceSink, err := conf.StartConference( + conferenceID, r.config, - r.matrix.CreateForConference(invite.ConfID), + r.matrix.CreateForConference(conferenceID), + evt.Sender, evt.Content.AsCallInvite(), ) - } - - peerID := conf.ParticipantID{ - UserID: evt.Sender, - DeviceID: invite.DeviceID, - } + if err != nil { + logger.WithError(err).Errorf("failed to start conference %s", conferenceID) + return + } - // Inform conference about incoming participant. - r.conferences[invite.ConfID].OnNewParticipant(peerID, invite) - - // Someone tries to send ICE candidates to the existing call. - case event.ToDeviceCallCandidates.Type: - candidates := evt.Content.AsCallCandidates() - - conference := r.conferences[candidates.ConfID] - if conference == nil { - logger.Errorf("received candidates for unknown conference %s", candidates.ConfID) + r.conferenceSinks[conferenceID] = conferenceSink return } - peerID := conf.ParticipantID{ - UserID: evt.Sender, - DeviceID: candidates.DeviceID, - } - - conference.OnCandidates(peerID, candidates) - - // Someone informs us about them accepting our (SFU's) SDP answer for an existing call. + conference <- conf.IncomingMatrixMessage{UserID: evt.Sender, Content: evt.Content.AsCallInvite()} + case event.ToDeviceCallCandidates.Type: + // Someone tries to send ICE candidates to the existing call. + conference <- conf.IncomingMatrixMessage{UserID: evt.Sender, Content: evt.Content.AsCallCandidates()} case event.ToDeviceCallSelectAnswer.Type: - selectAnswer := evt.Content.AsCallSelectAnswer() - - conference := r.conferences[selectAnswer.ConfID] - if conference == nil { - logger.Errorf("received select_answer for unknown conference %s", selectAnswer.ConfID) - return - } - - peerID := conf.ParticipantID{ - UserID: evt.Sender, - DeviceID: selectAnswer.DeviceID, - } - - conference.OnSelectAnswer(peerID, selectAnswer) - - // Someone tries to inform us about leaving an existing call. + // Someone informs us about them accepting our (SFU's) SDP answer for an existing call. + conference <- conf.IncomingMatrixMessage{UserID: evt.Sender, Content: evt.Content.AsCallSelectAnswer()} case event.ToDeviceCallHangup.Type: - hangup := evt.Content.AsCallHangup() - - conference := r.conferences[hangup.ConfID] - if conference == nil { - logger.Errorf("received hangup for unknown conference %s", hangup.ConfID) - return - } - - peerID := conf.ParticipantID{ - UserID: evt.Sender, - DeviceID: hangup.DeviceID, - } - - conference.OnHangup(peerID, hangup) - - // Events that we **should not** receive! - case event.ToDeviceCallNegotiate.Type: - logger.Warn("ignoring negotiate event that must be handled over the data channel") - case event.ToDeviceCallReject.Type: - logger.Warn("ignoring reject event that must be handled over the data channel") - case event.ToDeviceCallAnswer.Type: - logger.Warn("ignoring event as we are always the ones sending the SDP answer at the moment") + // Someone tries to inform us about leaving an existing call. + conference <- conf.IncomingMatrixMessage{UserID: evt.Sender, Content: evt.Content.AsCallHangup()} default: - logger.Warnf("ignoring unexpected event: %s", evt.Type.Type) + logger.Warnf("ignoring event that we must not receive: %s", evt.Type.Type) } } From 34ae746619cd9223cfcef7fa3a70d2c1038cb733 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Sat, 26 Nov 2022 21:34:23 +0100 Subject: [PATCH 15/62] refactor(conference): rename matrix message struct --- pkg/conference/matrix.go | 2 +- pkg/conference/processor.go | 2 +- pkg/conference/start.go | 4 ++-- pkg/conference/state.go | 2 +- pkg/router.go | 12 ++++++------ 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/pkg/conference/matrix.go b/pkg/conference/matrix.go index 7998626..b5119d8 100644 --- a/pkg/conference/matrix.go +++ b/pkg/conference/matrix.go @@ -11,7 +11,7 @@ import ( type MessageContent interface{} -type IncomingMatrixMessage struct { +type MatrixMessage struct { UserID id.UserID Content MessageContent } diff --git a/pkg/conference/processor.go b/pkg/conference/processor.go index acddc94..898bdda 100644 --- a/pkg/conference/processor.go +++ b/pkg/conference/processor.go @@ -155,7 +155,7 @@ func (c *Conference) handleDataChannelMessage(participant *Participant, sfuMessa } } -func (c *Conference) processMatrixMessage(msg IncomingMatrixMessage) { +func (c *Conference) processMatrixMessage(msg MatrixMessage) { switch ev := msg.Content.(type) { case *event.CallInviteEventContent: c.onNewParticipant(ParticipantID{UserID: msg.UserID, DeviceID: ev.DeviceID}, ev) diff --git a/pkg/conference/start.go b/pkg/conference/start.go index 3832ab7..a384690 100644 --- a/pkg/conference/start.go +++ b/pkg/conference/start.go @@ -32,8 +32,8 @@ func StartConference( signaling signaling.MatrixSignaling, UserID id.UserID, inviteEvent *event.CallInviteEventContent, -) (chan<- IncomingMatrixMessage, error) { - matrixBus := make(chan IncomingMatrixMessage) +) (chan<- MatrixMessage, error) { + matrixBus := make(chan MatrixMessage) conference := &Conference{ id: confID, diff --git a/pkg/conference/state.go b/pkg/conference/state.go index a185e12..37b08a0 100644 --- a/pkg/conference/state.go +++ b/pkg/conference/state.go @@ -20,7 +20,7 @@ type Conference struct { participants map[ParticipantID]*Participant peerMessages chan common.Message[ParticipantID, peer.MessageContent] - matrixBus <-chan IncomingMatrixMessage + matrixBus <-chan MatrixMessage } func (c *Conference) getParticipant(participantID ParticipantID, optionalErrorMessage error) *Participant { diff --git a/pkg/router.go b/pkg/router.go index 8e168d5..fa8588a 100644 --- a/pkg/router.go +++ b/pkg/router.go @@ -29,7 +29,7 @@ type Router struct { // Matrix matrix. matrix *signaling.MatrixClient // Sinks of all conferences (all calls that are currently forwarded by this SFU). - conferenceSinks map[string]chan<- conf.IncomingMatrixMessage + conferenceSinks map[string]chan<- conf.MatrixMessage // Configuration for the calls. config conf.Config } @@ -38,7 +38,7 @@ type Router struct { func newRouter(matrix *signaling.MatrixClient, config conf.Config) *Router { return &Router{ matrix: matrix, - conferenceSinks: make(map[string]chan<- conference.IncomingMatrixMessage), + conferenceSinks: make(map[string]chan<- conference.MatrixMessage), config: config, } } @@ -93,16 +93,16 @@ func (r *Router) handleMatrixEvent(evt *event.Event) { return } - conference <- conf.IncomingMatrixMessage{UserID: evt.Sender, Content: evt.Content.AsCallInvite()} + conference <- conf.MatrixMessage{UserID: evt.Sender, Content: evt.Content.AsCallInvite()} case event.ToDeviceCallCandidates.Type: // Someone tries to send ICE candidates to the existing call. - conference <- conf.IncomingMatrixMessage{UserID: evt.Sender, Content: evt.Content.AsCallCandidates()} + conference <- conf.MatrixMessage{UserID: evt.Sender, Content: evt.Content.AsCallCandidates()} case event.ToDeviceCallSelectAnswer.Type: // Someone informs us about them accepting our (SFU's) SDP answer for an existing call. - conference <- conf.IncomingMatrixMessage{UserID: evt.Sender, Content: evt.Content.AsCallSelectAnswer()} + conference <- conf.MatrixMessage{UserID: evt.Sender, Content: evt.Content.AsCallSelectAnswer()} case event.ToDeviceCallHangup.Type: // Someone tries to inform us about leaving an existing call. - conference <- conf.IncomingMatrixMessage{UserID: evt.Sender, Content: evt.Content.AsCallHangup()} + conference <- conf.MatrixMessage{UserID: evt.Sender, Content: evt.Content.AsCallHangup()} default: logger.Warnf("ignoring event that we must not receive: %s", evt.Type.Type) } From 877f7c2067df1b6150759fc8de4a678a87adfd12 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Sat, 26 Nov 2022 22:46:10 +0100 Subject: [PATCH 16/62] conference: inform owner when the conference ended This is a quick and a bit dirty implementation of the desired logic. We definitely must improve it in the future! --- pkg/common/channel.go | 37 ++++++++++++ pkg/conference/processor.go | 18 +++++- pkg/conference/start.go | 27 +++++---- pkg/conference/state.go | 11 ++-- pkg/main.go | 7 ++- pkg/router.go | 116 ++++++++++++++++++++++++++++++------ 6 files changed, 179 insertions(+), 37 deletions(-) create mode 100644 pkg/common/channel.go diff --git a/pkg/common/channel.go b/pkg/common/channel.go new file mode 100644 index 0000000..782d6dd --- /dev/null +++ b/pkg/common/channel.go @@ -0,0 +1,37 @@ +package common + +import "sync/atomic" + +// Creates a new channel, returns two counterparts of it where one can only send and another can only receive. +// Unlike traditional Go channels, these allow the receiver to mark the channel as closed which would then fail +// to send any messages to the channel over `Send“. +func NewChannel[M any]() (Sender[M], Receiver[M]) { + channel := make(chan M) + closed := &atomic.Bool{} + sender := Sender[M]{channel, closed} + receiver := Receiver[M]{channel, closed} + return sender, receiver +} + +type Sender[M any] struct { + channel chan<- M + receiverClosed *atomic.Bool +} + +func (s *Sender[M]) Send(message M) *M { + if !s.receiverClosed.Load() { + s.channel <- message + return nil + } else { + return &message + } +} + +type Receiver[M any] struct { + Channel <-chan M + receiverClosed *atomic.Bool +} + +func (r *Receiver[M]) Close() { + r.receiverClosed.Store(true) +} diff --git a/pkg/conference/processor.go b/pkg/conference/processor.go index 898bdda..2471717 100644 --- a/pkg/conference/processor.go +++ b/pkg/conference/processor.go @@ -18,13 +18,29 @@ func (c *Conference) processMessages() { select { case msg := <-c.peerMessages: c.processPeerMessage(msg) - case msg := <-c.matrixBus: + case msg := <-c.matrixMessages.Channel: c.processMatrixMessage(msg) } // If there are no more participants, stop the conference. if len(c.participants) == 0 { c.logger.Info("No more participants, stopping the conference") + // Close the channel so that the sender can't push any messages. + c.matrixMessages.Close() + + // Let's read remaining messages from the channel (otherwise the caller will be + // blocked in case of unbuffered channels). + var message *MatrixMessage + select { + case msg := <-c.matrixMessages.Channel: + *message = msg + default: + // Ok, no messages in the queue, nice. + } + + // Send the information that we ended to the owner and pass the message + // that we did not process (so that we don't drop it silently). + c.endNotifier.Notify(message) return } } diff --git a/pkg/conference/start.go b/pkg/conference/start.go index a384690..822daca 100644 --- a/pkg/conference/start.go +++ b/pkg/conference/start.go @@ -30,19 +30,21 @@ func StartConference( confID string, config Config, signaling signaling.MatrixSignaling, + conferenceEndNotifier ConferenceEndNotifier, UserID id.UserID, inviteEvent *event.CallInviteEventContent, -) (chan<- MatrixMessage, error) { - matrixBus := make(chan MatrixMessage) +) (*common.Sender[MatrixMessage], error) { + sender, receiver := common.NewChannel[MatrixMessage]() conference := &Conference{ - id: confID, - config: config, - signaling: signaling, - participants: make(map[ParticipantID]*Participant), - peerMessages: make(chan common.Message[ParticipantID, peer.MessageContent]), - matrixBus: matrixBus, - logger: logrus.WithFields(logrus.Fields{"conf_id": confID}), + id: confID, + config: config, + signaling: signaling, + matrixMessages: receiver, + endNotifier: conferenceEndNotifier, + participants: make(map[ParticipantID]*Participant), + peerMessages: make(chan common.Message[ParticipantID, peer.MessageContent]), + logger: logrus.WithFields(logrus.Fields{"conf_id": confID}), } participantID := ParticipantID{UserID: UserID, DeviceID: inviteEvent.DeviceID} @@ -53,5 +55,10 @@ func StartConference( // Start conference "main loop". go conference.processMessages() - return matrixBus, nil + return &sender, nil +} + +type ConferenceEndNotifier interface { + // Called when the conference ends. + Notify(unread *MatrixMessage) } diff --git a/pkg/conference/state.go b/pkg/conference/state.go index 37b08a0..2f067bc 100644 --- a/pkg/conference/state.go +++ b/pkg/conference/state.go @@ -12,15 +12,16 @@ import ( // A single conference. Call and conference mean the same in context of Matrix. type Conference struct { - id string - config Config - logger *logrus.Entry + id string + config Config + logger *logrus.Entry + endNotifier ConferenceEndNotifier signaling signaling.MatrixSignaling participants map[ParticipantID]*Participant - peerMessages chan common.Message[ParticipantID, peer.MessageContent] - matrixBus <-chan MatrixMessage + peerMessages chan common.Message[ParticipantID, peer.MessageContent] + matrixMessages common.Receiver[MatrixMessage] } func (c *Conference) getParticipant(participantID ParticipantID, optionalErrorMessage error) *Participant { diff --git a/pkg/main.go b/pkg/main.go index 855ba1d..6e09354 100644 --- a/pkg/main.go +++ b/pkg/main.go @@ -25,6 +25,7 @@ import ( "github.com/matrix-org/waterfall/pkg/config" "github.com/matrix-org/waterfall/pkg/signaling" "github.com/sirupsen/logrus" + "maunium.net/go/mautrix/event" ) func main() { @@ -72,8 +73,10 @@ func main() { matrixClient := signaling.NewMatrixClient(config.Matrix) // Create a router to route incoming To-Device messages to the right conference. - router := newRouter(matrixClient, config.Conference) + routerChannel := newRouter(matrixClient, config.Conference) // Start matrix client sync. This function will block until the sync fails. - matrixClient.RunSync(router.handleMatrixEvent) + matrixClient.RunSync(func(e *event.Event) { + routerChannel <- e + }) } diff --git a/pkg/router.go b/pkg/router.go index fa8588a..ef45eb4 100644 --- a/pkg/router.go +++ b/pkg/router.go @@ -17,30 +17,57 @@ limitations under the License. package main import ( - "github.com/matrix-org/waterfall/pkg/conference" + "github.com/matrix-org/waterfall/pkg/common" conf "github.com/matrix-org/waterfall/pkg/conference" "github.com/matrix-org/waterfall/pkg/signaling" "github.com/sirupsen/logrus" "maunium.net/go/mautrix/event" ) +type Conference = common.Sender[conf.MatrixMessage] + // The top-level state of the Router. type Router struct { // Matrix matrix. matrix *signaling.MatrixClient // Sinks of all conferences (all calls that are currently forwarded by this SFU). - conferenceSinks map[string]chan<- conf.MatrixMessage + conferenceSinks map[string]*Conference // Configuration for the calls. config conf.Config + // A channel to serialize all incoming events to the Router. + channel chan RouterMessage } // Creates a new instance of the SFU with the given configuration. -func newRouter(matrix *signaling.MatrixClient, config conf.Config) *Router { - return &Router{ +func newRouter(matrix *signaling.MatrixClient, config conf.Config) chan<- RouterMessage { + router := &Router{ matrix: matrix, - conferenceSinks: make(map[string]chan<- conference.MatrixMessage), + conferenceSinks: make(map[string]*common.Sender[conf.MatrixMessage]), config: config, + channel: make(chan RouterMessage), } + + // Start the main loop of the Router. + go func() { + for msg := range router.channel { + switch msg := msg.(type) { + // To-Device message received from the remote peer. + case MatrixMessage: + router.handleMatrixEvent(msg) + // One of the conferences has ended. + case ConferenceEndedMessage: + // Remove the conference that ended from the list. + delete(router.conferenceSinks, msg.conferenceID) + // Process the message that was not read by the conference. + if msg.unread != nil { + // TODO: We must handle this message to avoid glare on session end. + // router.handleMatrixEvent(*msg.unread) + } + } + } + }() + + return router.channel } // Handles incoming To-Device events that the SFU receives from clients. @@ -68,21 +95,15 @@ func (r *Router) handleMatrixEvent(evt *event.Event) { // Only ToDeviceCallInvite events are allowed to create a new conference, others // are expected to operate on an existing conference that is running on the SFU. if conference == nil && evt.Type.Type != event.ToDeviceCallInvite.Type { - logger.Warnf("ignoring %s for an unknown conference %s, ignoring", &event.ToDeviceCallInvite.Type) - return - } - - switch evt.Type.Type { - // Someone tries to participate in a call (join a call). - case event.ToDeviceCallInvite.Type: - // If there is an invitation sent and the conference does not exist, create one. - if conference == nil { + if evt.Type.Type == event.ToDeviceCallInvite.Type { logger.Infof("creating new conference %s", conferenceID) conferenceSink, err := conf.StartConference( conferenceID, r.config, r.matrix.CreateForConference(conferenceID), - evt.Sender, evt.Content.AsCallInvite(), + createConferenceEndNotifier(conferenceID, r.channel), + evt.Sender, + evt.Content.AsCallInvite(), ) if err != nil { logger.WithError(err).Errorf("failed to start conference %s", conferenceID) @@ -93,17 +114,74 @@ func (r *Router) handleMatrixEvent(evt *event.Event) { return } - conference <- conf.MatrixMessage{UserID: evt.Sender, Content: evt.Content.AsCallInvite()} + logger.Warnf("ignoring %s for an unknown conference %s, ignoring", &event.ToDeviceCallInvite.Type) + return + } + + // A helper function to deal with messages that can't be sent due to the conference closed. + // Not a function due to the need to capture local environment. + sendToConference := func(eventContent conf.MessageContent) { + // At this point the conference is not nil. + // Let's check if the channel is still available. + if conference.Send(conf.MatrixMessage{UserID: evt.Sender, Content: eventContent}) != nil { + // If sending failed, then the conference is over. + delete(r.conferenceSinks, conferenceID) + // Since we were not able to send the message, let's re-process it now. + // Note, we probably do not want to block here! + // TODO: Do it better (use buffered channels or something). + r.handleMatrixEvent(evt) + } + } + + switch evt.Type.Type { + // Someone tries to participate in a call (join a call). + case event.ToDeviceCallInvite.Type: + // If there is an invitation sent and the conference does not exist, create one. + sendToConference(evt.Content.AsCallInvite()) case event.ToDeviceCallCandidates.Type: // Someone tries to send ICE candidates to the existing call. - conference <- conf.MatrixMessage{UserID: evt.Sender, Content: evt.Content.AsCallCandidates()} + sendToConference(evt.Content.AsCallCandidates()) case event.ToDeviceCallSelectAnswer.Type: // Someone informs us about them accepting our (SFU's) SDP answer for an existing call. - conference <- conf.MatrixMessage{UserID: evt.Sender, Content: evt.Content.AsCallSelectAnswer()} + sendToConference(evt.Content.AsCallSelectAnswer()) case event.ToDeviceCallHangup.Type: // Someone tries to inform us about leaving an existing call. - conference <- conf.MatrixMessage{UserID: evt.Sender, Content: evt.Content.AsCallHangup()} + sendToConference(evt.Content.AsCallHangup()) default: logger.Warnf("ignoring event that we must not receive: %s", evt.Type.Type) } } + +type RouterMessage = interface{} + +type MatrixMessage = *event.Event + +// Message that is sent from the conference when the conference is ended. +type ConferenceEndedMessage struct { + // The ID of the conference that has ended. + conferenceID string + // A message (or messages in future) that has not been processed (if any). + unread *conf.MatrixMessage +} + +// A simple wrapper around channel that contains the ID of the conference that sent the message. +type ConferenceEndNotifier struct { + conferenceID string + channel chan<- interface{} +} + +// Crates a simple notifier with a conference with a given ID. +func createConferenceEndNotifier(conferenceID string, channel chan<- RouterMessage) *ConferenceEndNotifier { + return &ConferenceEndNotifier{ + conferenceID: conferenceID, + channel: channel, + } +} + +// A function that a conference calls when it is ended. +func (c *ConferenceEndNotifier) Notify(unread *conf.MatrixMessage) { + c.channel <- ConferenceEndedMessage{ + conferenceID: c.conferenceID, + unread: unread, + } +} From e2d73f9205c03139b1f0b7e7a33043d478c747ce Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Mon, 28 Nov 2022 13:14:59 +0100 Subject: [PATCH 17/62] general: remove custom logger It seems like it does not really make logs appear more readable unfortunately, especially in docker environment. We may want to switch to a different structured logging instead. See https://github.com/matrix-org/waterfall/issues/50 --- pkg/logger.go | 74 ---------------------------------------------- pkg/main.go | 3 +- scripts/profile.sh | 2 +- scripts/run.sh | 2 +- 4 files changed, 3 insertions(+), 78 deletions(-) delete mode 100644 pkg/logger.go diff --git a/pkg/logger.go b/pkg/logger.go deleted file mode 100644 index 9d725ae..0000000 --- a/pkg/logger.go +++ /dev/null @@ -1,74 +0,0 @@ -/* -Copyright 2022 The Matrix.org Foundation C.I.C. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package main - -import ( - "fmt" - - "github.com/sirupsen/logrus" -) - -func InitLogging(logTime bool) { - formatter := new(CustomTextFormatter) - formatter.logTime = logTime - logrus.SetFormatter(formatter) -} - -type CustomTextFormatter struct { - logrus.TextFormatter - logTime bool -} - -func (f *CustomTextFormatter) Format(entry *logrus.Entry) ([]byte, error) { - // TODO: Use colors and make it pretty - level := entry.Level - timestamp := entry.Time.Format("2006-01-02 15:04:05") - confID := entry.Data["conf_id"] - userID := entry.Data["user_id"] - - logLine := fmt.Sprintf("%s|", level) - - if f.logTime { - logLine += fmt.Sprintf("%s|", timestamp) - } - - if confID != nil { - logLine += fmt.Sprintf("%v|", confID) - } - - if userID != nil { - logLine += fmt.Sprintf("%v|", userID) - } - - logLine += fmt.Sprintf(" %v ", entry.Message) - - fields := "" - - for key, value := range entry.Data { - if key != "conf_id" && key != "user_id" { - fields += fmt.Sprintf("%v=%v ", key, value) - } - } - - if fields != "" { - logLine += fmt.Sprintf("| %s", fields) - } - - logLine += "\n" - - return []byte(logLine), nil -} diff --git a/pkg/main.go b/pkg/main.go index 6e09354..1fe4531 100644 --- a/pkg/main.go +++ b/pkg/main.go @@ -31,7 +31,6 @@ import ( func main() { // Parse command line flags. var ( - logTime = flag.Bool("logTime", false, "whether or not to print time and date in logs") configFilePath = flag.String("config", "config.yaml", "configuration file path") cpuProfile = flag.String("cpuProfile", "", "write CPU profile to `file`") memProfile = flag.String("memProfile", "", "write memory profile to `file`") @@ -39,7 +38,7 @@ func main() { flag.Parse() // Initialize logging subsystem (formatting, global logging framework etc). - InitLogging(*logTime) + logrus.SetFormatter(&logrus.TextFormatter{FullTimestamp: true}) // Define functions that are called before exiting. // This is useful to stop the profiler if it's enabled. diff --git a/scripts/profile.sh b/scripts/profile.sh index 0042a33..7bc0ee9 100755 --- a/scripts/profile.sh +++ b/scripts/profile.sh @@ -1,3 +1,3 @@ #!/usr/bin/env bash -go run ./pkg/*.go --cpuProfile cpuProfile.pprof --memProfile memProfile.pprof --logTime +go run ./pkg/*.go --cpuProfile cpuProfile.pprof --memProfile memProfile.pprof diff --git a/scripts/run.sh b/scripts/run.sh index 16e9b1f..6381271 100755 --- a/scripts/run.sh +++ b/scripts/run.sh @@ -1,3 +1,3 @@ #!/usr/bin/env bash -go run ./pkg --logTime +go run ./pkg From bd023b6bfe418aae0a6c54e3103ea5a3cf8c2957 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Mon, 28 Nov 2022 13:39:36 +0100 Subject: [PATCH 18/62] logger: skip checking the TTY Otherwise the logs in docker would be messed up. --- pkg/main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/main.go b/pkg/main.go index 1fe4531..b677702 100644 --- a/pkg/main.go +++ b/pkg/main.go @@ -38,7 +38,7 @@ func main() { flag.Parse() // Initialize logging subsystem (formatting, global logging framework etc). - logrus.SetFormatter(&logrus.TextFormatter{FullTimestamp: true}) + logrus.SetFormatter(&logrus.TextFormatter{FullTimestamp: true, ForceColors: true}) // Define functions that are called before exiting. // This is useful to stop the profiler if it's enabled. From b53d1979f6eb00aeaac14a94c3acb50a31398d23 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Mon, 28 Nov 2022 13:45:28 +0100 Subject: [PATCH 19/62] router: fix a typo for the OnInvite event --- pkg/router.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/router.go b/pkg/router.go index ef45eb4..6bfbcd9 100644 --- a/pkg/router.go +++ b/pkg/router.go @@ -94,7 +94,7 @@ func (r *Router) handleMatrixEvent(evt *event.Event) { // Only ToDeviceCallInvite events are allowed to create a new conference, others // are expected to operate on an existing conference that is running on the SFU. - if conference == nil && evt.Type.Type != event.ToDeviceCallInvite.Type { + if conference == nil && evt.Type.Type == event.ToDeviceCallInvite.Type { if evt.Type.Type == event.ToDeviceCallInvite.Type { logger.Infof("creating new conference %s", conferenceID) conferenceSink, err := conf.StartConference( From 40cb5fe68a1f087ab98d5f05c953a5ef077e59ab Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Mon, 28 Nov 2022 13:56:22 +0100 Subject: [PATCH 20/62] router: fix a typo on the type of the printf Otherwise, the address of the event type would be typed. --- pkg/router.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/router.go b/pkg/router.go index 6bfbcd9..19b651f 100644 --- a/pkg/router.go +++ b/pkg/router.go @@ -114,7 +114,7 @@ func (r *Router) handleMatrixEvent(evt *event.Event) { return } - logger.Warnf("ignoring %s for an unknown conference %s, ignoring", &event.ToDeviceCallInvite.Type) + logger.Warnf("ignoring %s since the conference is unknown", event.ToDeviceCallInvite.Type) return } From c629c7e3151e6513698a9f9eb5529404724c0326 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Mon, 28 Nov 2022 14:11:24 +0100 Subject: [PATCH 21/62] refactor(conference): simplify new participant --- pkg/conference/matrix.go | 59 +++++++++++++++++++--------------------- pkg/conference/state.go | 9 ++---- 2 files changed, 30 insertions(+), 38 deletions(-) diff --git a/pkg/conference/matrix.go b/pkg/conference/matrix.go index b5119d8..1372a29 100644 --- a/pkg/conference/matrix.go +++ b/pkg/conference/matrix.go @@ -25,49 +25,46 @@ func (c *Conference) onNewParticipant(participantID ParticipantID, inviteEvent * // As per MSC3401, when the `session_id` field changes from an incoming `m.call.member` event, // any existing calls from this device in this call should be terminated. - participant := c.getParticipant(participantID, nil) - if participant != nil { + if participant := c.participants[participantID]; participant != nil { if participant.remoteSessionID == inviteEvent.SenderSessionID { c.logger.Errorf("Found existing participant with equal DeviceID and SessionID") } else { c.removeParticipant(participantID) - participant = nil } } + participant := c.participants[participantID] + var sdpAnswer *webrtc.SessionDescription + // If participant exists still exists, then it means that the client does not behave properly. // In this case we treat this new invitation as a new SDP offer. Otherwise, we create a new one. - sdpAnswer, err := func() (*webrtc.SessionDescription, error) { - if participant == nil { - messageSink := common.NewMessageSink(participantID, c.peerMessages) - - peer, answer, err := peer.NewPeer(inviteEvent.Offer.SDP, messageSink, logger) - if err != nil { - return nil, err - } + if participant != nil { + answer, err := participant.peer.ProcessSDPOffer(inviteEvent.Offer.SDP) + if err != nil { + logger.WithError(err).Errorf("Failed to process SDP offer") + return err + } + sdpAnswer = answer + } else { + messageSink := common.NewMessageSink(participantID, c.peerMessages) - participant = &Participant{ - id: participantID, - peer: peer, - logger: logger, - remoteSessionID: inviteEvent.SenderSessionID, - streamMetadata: inviteEvent.SDPStreamMetadata, - publishedTracks: make(map[event.SFUTrackDescription]*webrtc.TrackLocalStaticRTP), - } + peer, answer, err := peer.NewPeer(inviteEvent.Offer.SDP, messageSink, logger) + if err != nil { + logger.WithError(err).Errorf("Failed to process SDP offer") + return err + } - c.participants[participantID] = participant - return answer, nil - } else { - answer, err := participant.peer.ProcessSDPOffer(inviteEvent.Offer.SDP) - if err != nil { - return nil, err - } - return answer, nil + participant = &Participant{ + id: participantID, + peer: peer, + logger: logger, + remoteSessionID: inviteEvent.SenderSessionID, + streamMetadata: inviteEvent.SDPStreamMetadata, + publishedTracks: make(map[event.SFUTrackDescription]*webrtc.TrackLocalStaticRTP), } - }() - if err != nil { - logger.WithError(err).Errorf("Failed to process SDP offer") - return err + + c.participants[participantID] = participant + sdpAnswer = answer } // Send the answer back to the remote peer. diff --git a/pkg/conference/state.go b/pkg/conference/state.go index 2f067bc..342f68e 100644 --- a/pkg/conference/state.go +++ b/pkg/conference/state.go @@ -27,15 +27,10 @@ type Conference struct { func (c *Conference) getParticipant(participantID ParticipantID, optionalErrorMessage error) *Participant { participant, ok := c.participants[participantID] if !ok { - logEntry := c.logger.WithFields(logrus.Fields{ - "user_id": participantID.UserID, - "device_id": participantID.DeviceID, - }) - if optionalErrorMessage != nil { - logEntry.WithError(optionalErrorMessage) + c.logger.WithError(optionalErrorMessage) } else { - logEntry.Error("Participant not found") + c.logger.Error("Participant not found") } return nil From 327419494867c665c00e6860e16a03e1ab69a6f9 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Mon, 28 Nov 2022 14:21:00 +0100 Subject: [PATCH 22/62] conference: add additional logging --- pkg/conference/matrix.go | 11 ++++++++++- pkg/conference/processor.go | 21 +++++++++++++++++++++ pkg/peer/webrtc.go | 5 +---- 3 files changed, 32 insertions(+), 5 deletions(-) diff --git a/pkg/conference/matrix.go b/pkg/conference/matrix.go index 1372a29..ba506d5 100644 --- a/pkg/conference/matrix.go +++ b/pkg/conference/matrix.go @@ -23,6 +23,8 @@ func (c *Conference) onNewParticipant(participantID ParticipantID, inviteEvent * "device_id": participantID.DeviceID, }) + logger.Info("Incoming call invite") + // As per MSC3401, when the `session_id` field changes from an incoming `m.call.member` event, // any existing calls from this device in this call should be terminated. if participant := c.participants[participantID]; participant != nil { @@ -78,6 +80,8 @@ func (c *Conference) onNewParticipant(participantID ParticipantID, inviteEvent * // our internal peer connection. func (c *Conference) onCandidates(participantID ParticipantID, ev *event.CallCandidatesEventContent) { if participant := c.getParticipant(participantID, nil); participant != nil { + participant.logger.Info("Received remote ICE candidates") + // Convert the candidates to the WebRTC format. candidates := make([]webrtc.ICECandidateInit, len(ev.Candidates)) for i, candidate := range ev.Candidates { @@ -97,6 +101,8 @@ func (c *Conference) onCandidates(participantID ParticipantID, ev *event.CallCan // and that the call can now proceed. func (c *Conference) onSelectAnswer(participantID ParticipantID, ev *event.CallSelectAnswerEventContent) { if participant := c.getParticipant(participantID, nil); participant != nil { + participant.logger.Info("Received remote answer selection") + if ev.SelectedPartyID != participantID.DeviceID.String() { c.logger.WithFields(logrus.Fields{ "device_id": ev.SelectedPartyID, @@ -109,5 +115,8 @@ func (c *Conference) onSelectAnswer(participantID ParticipantID, ev *event.CallS // Process a message from the remote peer telling that it wants to hang up the call. func (c *Conference) onHangup(participantID ParticipantID, ev *event.CallHangupEventContent) { - c.removeParticipant(participantID) + if participant := c.participants[participantID]; participant != nil { + participant.logger.Info("Received remote hangup") + c.removeParticipant(participantID) + } } diff --git a/pkg/conference/processor.go b/pkg/conference/processor.go index 2471717..f5ff2dc 100644 --- a/pkg/conference/processor.go +++ b/pkg/conference/processor.go @@ -57,13 +57,16 @@ func (c *Conference) processPeerMessage(message common.Message[ParticipantID, pe // determine the actual type of the message. switch msg := message.Content.(type) { case peer.JoinedTheCall: + participant.logger.Info("Joined the call") c.resendMetadataToAllExcept(participant.id) case peer.LeftTheCall: + participant.logger.Info("Left the call") c.removeParticipant(message.Sender) c.signaling.SendHangup(participant.asMatrixRecipient(), event.CallHangupUnknownError) case peer.NewTrackPublished: + participant.logger.Infof("Published new track: %s", msg.Track.ID()) key := event.SFUTrackDescription{ StreamID: msg.Track.StreamID(), TrackID: msg.Track.ID(), @@ -77,6 +80,7 @@ func (c *Conference) processPeerMessage(message common.Message[ParticipantID, pe participant.publishedTracks[key] = msg.Track case peer.PublishedTrackFailed: + participant.logger.Infof("Failed published track: %s", msg.Track.ID()) delete(participant.publishedTracks, event.SFUTrackDescription{ StreamID: msg.Track.StreamID(), TrackID: msg.Track.ID(), @@ -91,6 +95,8 @@ func (c *Conference) processPeerMessage(message common.Message[ParticipantID, pe } case peer.NewICECandidate: + participant.logger.Info("Received a new local ICE candidate") + // Convert WebRTC ICE candidate to Matrix ICE candidate. jsonCandidate := msg.Candidate.ToJSON() candidates := []event.CallCandidate{{ @@ -101,10 +107,13 @@ func (c *Conference) processPeerMessage(message common.Message[ParticipantID, pe c.signaling.SendICECandidates(participant.asMatrixRecipient(), candidates) case peer.ICEGatheringComplete: + participant.logger.Info("Completed local ICE gathering") + // Send an empty array of candidates to indicate that ICE gathering is complete. c.signaling.SendCandidatesGatheringFinished(participant.asMatrixRecipient()) case peer.RenegotiationRequired: + participant.logger.Info("Started renegotiation") participant.sendDataChannelMessage(event.SFUMessage{ Op: event.SFUOperationOffer, SDP: msg.Offer.SDP, @@ -112,6 +121,7 @@ func (c *Conference) processPeerMessage(message common.Message[ParticipantID, pe }) case peer.DataChannelMessage: + participant.logger.Info("Sent data channel message") var sfuMessage event.SFUMessage if err := json.Unmarshal([]byte(msg.Message), &sfuMessage); err != nil { c.logger.Errorf("Failed to unmarshal SFU message: %v", err) @@ -121,6 +131,7 @@ func (c *Conference) processPeerMessage(message common.Message[ParticipantID, pe c.handleDataChannelMessage(participant, sfuMessage) case peer.DataChannelAvailable: + participant.logger.Info("Connected data channel") participant.sendDataChannelMessage(event.SFUMessage{ Op: event.SFUOperationMetadata, Metadata: c.getAvailableStreamsFor(participant.id), @@ -135,6 +146,8 @@ func (c *Conference) processPeerMessage(message common.Message[ParticipantID, pe func (c *Conference) handleDataChannelMessage(participant *Participant, sfuMessage event.SFUMessage) { switch sfuMessage.Op { case event.SFUOperationSelect: + participant.logger.Info("Sent select request over DC") + // Get the tracks that correspond to the tracks that the participant wants to receive. for _, track := range c.getTracks(sfuMessage.Start) { if err := participant.peer.SubscribeTo(track); err != nil { @@ -144,12 +157,16 @@ func (c *Conference) handleDataChannelMessage(participant *Participant, sfuMessa } case event.SFUOperationAnswer: + participant.logger.Info("Sent SDP answer over DC") + if err := participant.peer.ProcessSDPAnswer(sfuMessage.SDP); err != nil { participant.logger.Errorf("Failed to set SDP answer: %v", err) return } case event.SFUOperationPublish: + participant.logger.Info("Sent SDP offer over DC") + answer, err := participant.peer.ProcessSDPOffer(sfuMessage.SDP) if err != nil { participant.logger.Errorf("Failed to set SDP offer: %v", err) @@ -162,10 +179,14 @@ func (c *Conference) handleDataChannelMessage(participant *Participant, sfuMessa }) case event.SFUOperationUnpublish: + participant.logger.Info("Sent unpublish over DC") + // TODO: Clarify the semantics of unpublish. case event.SFUOperationAlive: // FIXME: Handle the heartbeat message here (updating the last timestamp etc). case event.SFUOperationMetadata: + participant.logger.Info("Sent metadata over DC") + participant.streamMetadata = sfuMessage.Metadata c.resendMetadataToAllExcept(participant.id) } diff --git a/pkg/peer/webrtc.go b/pkg/peer/webrtc.go index 741f135..82238c7 100644 --- a/pkg/peer/webrtc.go +++ b/pkg/peer/webrtc.go @@ -71,6 +71,7 @@ func (p *Peer[ID]) onRtpTrackReceived(remoteTrack *webrtc.TrackRemote, receiver func (p *Peer[ID]) onICECandidateGathered(candidate *webrtc.ICECandidate) { if candidate == nil { p.logger.Info("ICE candidate gathering finished") + p.sink.Send(ICEGatheringComplete{}) return } @@ -112,10 +113,6 @@ func (p *Peer[ID]) onICEConnectionStateChanged(state webrtc.ICEConnectionState) func (p *Peer[ID]) onICEGatheringStateChanged(state webrtc.ICEGathererState) { p.logger.WithField("state", state).Debug("ICE gathering state changed") - - if state == webrtc.ICEGathererStateComplete { - p.sink.Send(ICEGatheringComplete{}) - } } func (p *Peer[ID]) onSignalingStateChanged(state webrtc.SignalingState) { From 6b31a98ba880b698e5ddebf6aa84af5835edd9e6 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Mon, 28 Nov 2022 14:55:17 +0100 Subject: [PATCH 23/62] router: ignore messages for unknown conferences --- pkg/router.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pkg/router.go b/pkg/router.go index 19b651f..04d5509 100644 --- a/pkg/router.go +++ b/pkg/router.go @@ -118,6 +118,12 @@ func (r *Router) handleMatrixEvent(evt *event.Event) { return } + // All other events are expected to be handled by the existing conference. + if conference == nil { + logger.Warnf("ignoring %s since the conference is unknown", evt.Type.Type) + return + } + // A helper function to deal with messages that can't be sent due to the conference closed. // Not a function due to the need to capture local environment. sendToConference := func(eventContent conf.MessageContent) { From b991a3a8f2e934bcf802eb52e8226712665e7fa2 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Mon, 28 Nov 2022 15:06:40 +0100 Subject: [PATCH 24/62] signaling: fix wrongly set party ID --- pkg/signaling/matrix.go | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/pkg/signaling/matrix.go b/pkg/signaling/matrix.go index 0c84331..91b4017 100644 --- a/pkg/signaling/matrix.go +++ b/pkg/signaling/matrix.go @@ -61,7 +61,7 @@ func (m *MatrixForConference) SendSDPAnswer( ) { eventContent := &event.Content{ Parsed: event.CallAnswerEventContent{ - BaseCallEventContent: m.createBaseEventContent(recipient.DeviceID, recipient.RemoteSessionID), + BaseCallEventContent: m.createBaseEventContent(recipient.RemoteSessionID), Answer: event.CallData{ Type: "answer", SDP: sdp, @@ -76,7 +76,7 @@ func (m *MatrixForConference) SendSDPAnswer( func (m *MatrixForConference) SendICECandidates(recipient MatrixRecipient, candidates []event.CallCandidate) { eventContent := &event.Content{ Parsed: event.CallCandidatesEventContent{ - BaseCallEventContent: m.createBaseEventContent(recipient.DeviceID, recipient.RemoteSessionID), + BaseCallEventContent: m.createBaseEventContent(recipient.RemoteSessionID), Candidates: candidates, }, } @@ -87,7 +87,7 @@ func (m *MatrixForConference) SendICECandidates(recipient MatrixRecipient, candi func (m *MatrixForConference) SendCandidatesGatheringFinished(recipient MatrixRecipient) { eventContent := &event.Content{ Parsed: event.CallCandidatesEventContent{ - BaseCallEventContent: m.createBaseEventContent(recipient.DeviceID, recipient.RemoteSessionID), + BaseCallEventContent: m.createBaseEventContent(recipient.RemoteSessionID), Candidates: []event.CallCandidate{{Candidate: ""}}, }, } @@ -98,7 +98,7 @@ func (m *MatrixForConference) SendCandidatesGatheringFinished(recipient MatrixRe func (m *MatrixForConference) SendHangup(recipient MatrixRecipient, reason event.CallHangupReason) { eventContent := &event.Content{ Parsed: event.CallHangupEventContent{ - BaseCallEventContent: m.createBaseEventContent(recipient.DeviceID, recipient.RemoteSessionID), + BaseCallEventContent: m.createBaseEventContent(recipient.RemoteSessionID), Reason: reason, }, } @@ -106,17 +106,14 @@ func (m *MatrixForConference) SendHangup(recipient MatrixRecipient, reason event m.sendToDevice(recipient, event.CallHangup, eventContent) } -func (m *MatrixForConference) createBaseEventContent( - destDeviceID id.DeviceID, - destSessionID id.SessionID, -) event.BaseCallEventContent { +func (m *MatrixForConference) createBaseEventContent(destSessionID id.SessionID) event.BaseCallEventContent { return event.BaseCallEventContent{ CallID: m.conferenceID, ConfID: m.conferenceID, DeviceID: m.client.DeviceID, SenderSessionID: LocalSessionID, DestSessionID: destSessionID, - PartyID: string(destDeviceID), + PartyID: string(m.client.DeviceID), Version: event.CallVersion("1"), } } From 9d456742eec54397c4dcf42169959d29a1305d21 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Mon, 28 Nov 2022 15:20:06 +0100 Subject: [PATCH 25/62] ice: add an empty user fragment to the ICE --- pkg/conference/matrix.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pkg/conference/matrix.go b/pkg/conference/matrix.go index ba506d5..bfaa955 100644 --- a/pkg/conference/matrix.go +++ b/pkg/conference/matrix.go @@ -87,9 +87,10 @@ func (c *Conference) onCandidates(participantID ParticipantID, ev *event.CallCan for i, candidate := range ev.Candidates { SDPMLineIndex := uint16(candidate.SDPMLineIndex) candidates[i] = webrtc.ICECandidateInit{ - Candidate: candidate.Candidate, - SDPMid: &candidate.SDPMID, - SDPMLineIndex: &SDPMLineIndex, + Candidate: candidate.Candidate, + SDPMid: &candidate.SDPMID, + SDPMLineIndex: &SDPMLineIndex, + UsernameFragment: new(string), } } From 1c3b7024f15c8f8ff4de1e0afe25b0118cca5ce9 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Mon, 28 Nov 2022 15:27:42 +0100 Subject: [PATCH 26/62] peer: use GatheringCompletePromise from Pion This is just to test if the outcome of ICE negotiation changes. So far the connection does not get established despite an exchange of ICE candidates. --- pkg/peer/peer.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pkg/peer/peer.go b/pkg/peer/peer.go index a6ef022..839d6f5 100644 --- a/pkg/peer/peer.go +++ b/pkg/peer/peer.go @@ -180,12 +180,16 @@ func (p *Peer[ID]) ProcessSDPOffer(sdpOffer string) (*webrtc.SessionDescription, return nil, ErrCantCreateAnswer } + // TODO: Do we really need to call `webrtc.GatheringCompletePromise` here? + gatherComplete := webrtc.GatheringCompletePromise(p.peerConnection) + if err := p.peerConnection.SetLocalDescription(answer); err != nil { p.logger.WithError(err).Error("failed to set local description") return nil, ErrCantSetLocalDescription } - // TODO: Do we really need to call `webrtc.GatheringCompletePromise` here? + // Block until ICE Gathering is complete. + <-gatherComplete sdpAnswer := p.peerConnection.LocalDescription() if sdpAnswer == nil { From 82dac24f97286a1e3cd80d057029c322bc194a47 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Mon, 28 Nov 2022 15:39:18 +0100 Subject: [PATCH 27/62] Revert "peer: use GatheringCompletePromise from Pion" This reverts commit 9d20fb3a3f64878d2fd1aa2c3cff4e9a69c28c2e. --- pkg/peer/peer.go | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/pkg/peer/peer.go b/pkg/peer/peer.go index 839d6f5..a6ef022 100644 --- a/pkg/peer/peer.go +++ b/pkg/peer/peer.go @@ -180,16 +180,12 @@ func (p *Peer[ID]) ProcessSDPOffer(sdpOffer string) (*webrtc.SessionDescription, return nil, ErrCantCreateAnswer } - // TODO: Do we really need to call `webrtc.GatheringCompletePromise` here? - gatherComplete := webrtc.GatheringCompletePromise(p.peerConnection) - if err := p.peerConnection.SetLocalDescription(answer); err != nil { p.logger.WithError(err).Error("failed to set local description") return nil, ErrCantSetLocalDescription } - // Block until ICE Gathering is complete. - <-gatherComplete + // TODO: Do we really need to call `webrtc.GatheringCompletePromise` here? sdpAnswer := p.peerConnection.LocalDescription() if sdpAnswer == nil { From 415e24275cf6011e5b0c452e0460e324ba5f7033 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Mon, 28 Nov 2022 16:47:51 +0100 Subject: [PATCH 28/62] conference: add additional call invite logging --- pkg/conference/matrix.go | 1 + pkg/main.go | 3 +++ pkg/peer/peer.go | 3 +-- pkg/router.go | 29 ++++++++++++----------------- 4 files changed, 17 insertions(+), 19 deletions(-) diff --git a/pkg/conference/matrix.go b/pkg/conference/matrix.go index bfaa955..f4ba6f5 100644 --- a/pkg/conference/matrix.go +++ b/pkg/conference/matrix.go @@ -72,6 +72,7 @@ func (c *Conference) onNewParticipant(participantID ParticipantID, inviteEvent * // Send the answer back to the remote peer. recipient := participant.asMatrixRecipient() streamMetadata := c.getAvailableStreamsFor(participantID) + participant.logger.WithField("sdpAnswer", sdpAnswer.SDP).Debug("Sending SDP answer") c.signaling.SendSDPAnswer(recipient, streamMetadata, sdpAnswer.SDP) return nil } diff --git a/pkg/main.go b/pkg/main.go index b677702..9d92ed3 100644 --- a/pkg/main.go +++ b/pkg/main.go @@ -40,6 +40,9 @@ func main() { // Initialize logging subsystem (formatting, global logging framework etc). logrus.SetFormatter(&logrus.TextFormatter{FullTimestamp: true, ForceColors: true}) + // Temporarily enable debug logging. + logrus.SetLevel(logrus.DebugLevel) + // Define functions that are called before exiting. // This is useful to stop the profiler if it's enabled. deferred_functions := []func(){} diff --git a/pkg/peer/peer.go b/pkg/peer/peer.go index a6ef022..bfc2e05 100644 --- a/pkg/peer/peer.go +++ b/pkg/peer/peer.go @@ -165,6 +165,7 @@ func (p *Peer[ID]) ProcessSDPAnswer(sdpAnswer string) error { // Applies the sdp offer received from the remote peer and generates an SDP answer. func (p *Peer[ID]) ProcessSDPOffer(sdpOffer string) (*webrtc.SessionDescription, error) { + p.logger.WithField("sdpOffer", sdpOffer).Debug("processing SDP offer") err := p.peerConnection.SetRemoteDescription(webrtc.SessionDescription{ Type: webrtc.SDPTypeOffer, SDP: sdpOffer, @@ -185,8 +186,6 @@ func (p *Peer[ID]) ProcessSDPOffer(sdpOffer string) (*webrtc.SessionDescription, return nil, ErrCantSetLocalDescription } - // TODO: Do we really need to call `webrtc.GatheringCompletePromise` here? - sdpAnswer := p.peerConnection.LocalDescription() if sdpAnswer == nil { p.logger.WithError(err).Error("could not generate a local description") diff --git a/pkg/router.go b/pkg/router.go index 04d5509..3d1b9c2 100644 --- a/pkg/router.go +++ b/pkg/router.go @@ -95,26 +95,21 @@ func (r *Router) handleMatrixEvent(evt *event.Event) { // Only ToDeviceCallInvite events are allowed to create a new conference, others // are expected to operate on an existing conference that is running on the SFU. if conference == nil && evt.Type.Type == event.ToDeviceCallInvite.Type { - if evt.Type.Type == event.ToDeviceCallInvite.Type { - logger.Infof("creating new conference %s", conferenceID) - conferenceSink, err := conf.StartConference( - conferenceID, - r.config, - r.matrix.CreateForConference(conferenceID), - createConferenceEndNotifier(conferenceID, r.channel), - evt.Sender, - evt.Content.AsCallInvite(), - ) - if err != nil { - logger.WithError(err).Errorf("failed to start conference %s", conferenceID) - return - } - - r.conferenceSinks[conferenceID] = conferenceSink + logger.Infof("creating new conference %s", conferenceID) + conferenceSink, err := conf.StartConference( + conferenceID, + r.config, + r.matrix.CreateForConference(conferenceID), + createConferenceEndNotifier(conferenceID, r.channel), + evt.Sender, + evt.Content.AsCallInvite(), + ) + if err != nil { + logger.WithError(err).Errorf("failed to start conference %s", conferenceID) return } - logger.Warnf("ignoring %s since the conference is unknown", event.ToDeviceCallInvite.Type) + r.conferenceSinks[conferenceID] = conferenceSink return } From cc096927ba20f9f2cc564bd9dbfdb416774d37fd Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Mon, 28 Nov 2022 20:28:43 +0100 Subject: [PATCH 29/62] conference: fix unsoundness in channel usage --- pkg/common/channel.go | 2 +- pkg/conference/processor.go | 17 +++++++++-------- pkg/conference/start.go | 4 ++-- pkg/router.go | 12 ++++++------ 4 files changed, 18 insertions(+), 17 deletions(-) diff --git a/pkg/common/channel.go b/pkg/common/channel.go index 782d6dd..73dc32f 100644 --- a/pkg/common/channel.go +++ b/pkg/common/channel.go @@ -6,7 +6,7 @@ import "sync/atomic" // Unlike traditional Go channels, these allow the receiver to mark the channel as closed which would then fail // to send any messages to the channel over `Send“. func NewChannel[M any]() (Sender[M], Receiver[M]) { - channel := make(chan M) + channel := make(chan M, 128) closed := &atomic.Bool{} sender := Sender[M]{channel, closed} receiver := Receiver[M]{channel, closed} diff --git a/pkg/conference/processor.go b/pkg/conference/processor.go index f5ff2dc..2a43726 100644 --- a/pkg/conference/processor.go +++ b/pkg/conference/processor.go @@ -29,18 +29,19 @@ func (c *Conference) processMessages() { c.matrixMessages.Close() // Let's read remaining messages from the channel (otherwise the caller will be - // blocked in case of unbuffered channels). - var message *MatrixMessage - select { - case msg := <-c.matrixMessages.Channel: - *message = msg - default: - // Ok, no messages in the queue, nice. + // blocked in case of unbuffered channels). We must read **all** pending messages. + messages := make([]MatrixMessage, 0) + for { + msg, ok := <-c.matrixMessages.Channel + if !ok { + break + } + messages = append(messages, msg) } // Send the information that we ended to the owner and pass the message // that we did not process (so that we don't drop it silently). - c.endNotifier.Notify(message) + c.endNotifier.Notify(messages) return } } diff --git a/pkg/conference/start.go b/pkg/conference/start.go index 822daca..871360b 100644 --- a/pkg/conference/start.go +++ b/pkg/conference/start.go @@ -43,7 +43,7 @@ func StartConference( matrixMessages: receiver, endNotifier: conferenceEndNotifier, participants: make(map[ParticipantID]*Participant), - peerMessages: make(chan common.Message[ParticipantID, peer.MessageContent]), + peerMessages: make(chan common.Message[ParticipantID, peer.MessageContent], 128), logger: logrus.WithFields(logrus.Fields{"conf_id": confID}), } @@ -60,5 +60,5 @@ func StartConference( type ConferenceEndNotifier interface { // Called when the conference ends. - Notify(unread *MatrixMessage) + Notify(unread []MatrixMessage) } diff --git a/pkg/router.go b/pkg/router.go index 3d1b9c2..00bebb8 100644 --- a/pkg/router.go +++ b/pkg/router.go @@ -44,7 +44,7 @@ func newRouter(matrix *signaling.MatrixClient, config conf.Config) chan<- Router matrix: matrix, conferenceSinks: make(map[string]*common.Sender[conf.MatrixMessage]), config: config, - channel: make(chan RouterMessage), + channel: make(chan RouterMessage, 128), } // Start the main loop of the Router. @@ -59,9 +59,9 @@ func newRouter(matrix *signaling.MatrixClient, config conf.Config) chan<- Router // Remove the conference that ended from the list. delete(router.conferenceSinks, msg.conferenceID) // Process the message that was not read by the conference. - if msg.unread != nil { - // TODO: We must handle this message to avoid glare on session end. - // router.handleMatrixEvent(*msg.unread) + if len(msg.unread) > 0 { + // FIXME: We must handle these messages! + logrus.Warnf("Unread messages: %v", len(msg.unread)) } } } @@ -162,7 +162,7 @@ type ConferenceEndedMessage struct { // The ID of the conference that has ended. conferenceID string // A message (or messages in future) that has not been processed (if any). - unread *conf.MatrixMessage + unread []conf.MatrixMessage } // A simple wrapper around channel that contains the ID of the conference that sent the message. @@ -180,7 +180,7 @@ func createConferenceEndNotifier(conferenceID string, channel chan<- RouterMessa } // A function that a conference calls when it is ended. -func (c *ConferenceEndNotifier) Notify(unread *conf.MatrixMessage) { +func (c *ConferenceEndNotifier) Notify(unread []conf.MatrixMessage) { c.channel <- ConferenceEndedMessage{ conferenceID: c.conferenceID, unread: unread, From db900b2f3778cfc4aef259d86e67396f3ec7a128 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Mon, 28 Nov 2022 21:27:16 +0100 Subject: [PATCH 30/62] signaling: add additional debug logs --- pkg/signaling/matrix.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/pkg/signaling/matrix.go b/pkg/signaling/matrix.go index 91b4017..ff0fe49 100644 --- a/pkg/signaling/matrix.go +++ b/pkg/signaling/matrix.go @@ -17,6 +17,8 @@ limitations under the License. package signaling import ( + "encoding/json" + "github.com/sirupsen/logrus" "maunium.net/go/mautrix" "maunium.net/go/mautrix/event" @@ -133,6 +135,16 @@ func (m *MatrixForConference) sendToDevice(user MatrixRecipient, eventType event }, } + { + // TODO: Remove this once + serialized, err := json.Marshal(sendRequest) + if err != nil { + logger.WithError(err).Error("Failed to serialize to-device message") + return + } + logger.Debugf("Sending to-device message: %s", string(serialized)) + } + if _, err := m.client.SendToDevice(eventType, sendRequest); err != nil { logger.Errorf("failed to send to-device event: %w", err) } From 2f809b1113c540fd09517188471f9a45df024b8a Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Mon, 28 Nov 2022 21:54:00 +0100 Subject: [PATCH 31/62] conference: add `call_id` handling The `call_id` does seem to be after all not the same as `conf_id`! --- pkg/conference/matrix.go | 7 ++--- pkg/conference/participant.go | 2 ++ pkg/conference/processor.go | 8 +++--- pkg/conference/start.go | 2 +- pkg/router.go | 51 +++++++++++++++++++++++------------ pkg/signaling/matrix.go | 16 ++++++----- 6 files changed, 55 insertions(+), 31 deletions(-) diff --git a/pkg/conference/matrix.go b/pkg/conference/matrix.go index f4ba6f5..ed6e853 100644 --- a/pkg/conference/matrix.go +++ b/pkg/conference/matrix.go @@ -6,14 +6,14 @@ import ( "github.com/pion/webrtc/v3" "github.com/sirupsen/logrus" "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" ) type MessageContent interface{} type MatrixMessage struct { - UserID id.UserID - Content MessageContent + Sender ParticipantID + Content MessageContent + RawEvent *event.Event } // New participant tries to join the conference. @@ -21,6 +21,7 @@ func (c *Conference) onNewParticipant(participantID ParticipantID, inviteEvent * logger := c.logger.WithFields(logrus.Fields{ "user_id": participantID.UserID, "device_id": participantID.DeviceID, + "call_id": participantID.CallID, }) logger.Info("Incoming call invite") diff --git a/pkg/conference/participant.go b/pkg/conference/participant.go index a7b9385..61b9df2 100644 --- a/pkg/conference/participant.go +++ b/pkg/conference/participant.go @@ -16,6 +16,7 @@ import ( type ParticipantID struct { UserID id.UserID DeviceID id.DeviceID + CallID string } // Participant represents a participant in the conference. @@ -32,6 +33,7 @@ func (p *Participant) asMatrixRecipient() signaling.MatrixRecipient { return signaling.MatrixRecipient{ UserID: p.id.UserID, DeviceID: p.id.DeviceID, + CallID: p.id.CallID, RemoteSessionID: p.remoteSessionID, } } diff --git a/pkg/conference/processor.go b/pkg/conference/processor.go index 2a43726..fbaf0b9 100644 --- a/pkg/conference/processor.go +++ b/pkg/conference/processor.go @@ -196,13 +196,13 @@ func (c *Conference) handleDataChannelMessage(participant *Participant, sfuMessa func (c *Conference) processMatrixMessage(msg MatrixMessage) { switch ev := msg.Content.(type) { case *event.CallInviteEventContent: - c.onNewParticipant(ParticipantID{UserID: msg.UserID, DeviceID: ev.DeviceID}, ev) + c.onNewParticipant(msg.Sender, ev) case *event.CallCandidatesEventContent: - c.onCandidates(ParticipantID{UserID: msg.UserID, DeviceID: ev.DeviceID}, ev) + c.onCandidates(msg.Sender, ev) case *event.CallSelectAnswerEventContent: - c.onSelectAnswer(ParticipantID{UserID: msg.UserID, DeviceID: ev.DeviceID}, ev) + c.onSelectAnswer(msg.Sender, ev) case *event.CallHangupEventContent: - c.onHangup(ParticipantID{UserID: msg.UserID, DeviceID: ev.DeviceID}, ev) + c.onHangup(msg.Sender, ev) default: c.logger.Errorf("Unexpected event type: %T", ev) } diff --git a/pkg/conference/start.go b/pkg/conference/start.go index 871360b..02d12b2 100644 --- a/pkg/conference/start.go +++ b/pkg/conference/start.go @@ -47,7 +47,7 @@ func StartConference( logger: logrus.WithFields(logrus.Fields{"conf_id": confID}), } - participantID := ParticipantID{UserID: UserID, DeviceID: inviteEvent.DeviceID} + participantID := ParticipantID{UserID: UserID, DeviceID: inviteEvent.DeviceID, CallID: inviteEvent.CallID} if err := conference.onNewParticipant(participantID, inviteEvent); err != nil { return nil, err } diff --git a/pkg/router.go b/pkg/router.go index 00bebb8..b2a9630 100644 --- a/pkg/router.go +++ b/pkg/router.go @@ -22,6 +22,7 @@ import ( "github.com/matrix-org/waterfall/pkg/signaling" "github.com/sirupsen/logrus" "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" ) type Conference = common.Sender[conf.MatrixMessage] @@ -58,10 +59,11 @@ func newRouter(matrix *signaling.MatrixClient, config conf.Config) chan<- Router case ConferenceEndedMessage: // Remove the conference that ended from the list. delete(router.conferenceSinks, msg.conferenceID) + // Process the message that was not read by the conference. - if len(msg.unread) > 0 { - // FIXME: We must handle these messages! - logrus.Warnf("Unread messages: %v", len(msg.unread)) + for _, msg := range msg.unread { + // TODO: We actually already know the type, so we can do this better. + router.handleMatrixEvent(msg.RawEvent) } } } @@ -72,22 +74,36 @@ func newRouter(matrix *signaling.MatrixClient, config conf.Config) chan<- Router // Handles incoming To-Device events that the SFU receives from clients. func (r *Router) handleMatrixEvent(evt *event.Event) { - // Check if `conf_id` is present in the message (right messages do have it). - rawConferenceID, ok := evt.Content.Raw["conf_id"] - if !ok { - return - } + var ( + conferenceID string + callID string + deviceID string + userID = evt.Sender + ) - // Try to parse the conference ID without parsing the whole event. - conferenceID, ok := rawConferenceID.(string) - if !ok { - return + // Check if `conf_id` is present in the message (right messages do have it). + rawConferenceID, okConferenceId := evt.Content.Raw["conf_id"] + rawCallID, okCallId := evt.Content.Raw["call_id"] + rawDeviceID, okDeviceID := evt.Content.Raw["device_id"] + + if okConferenceId && okCallId && okDeviceID { + // Extract the conference ID from the message. + conferenceID, okConferenceId = rawConferenceID.(string) + callID, okCallId = rawCallID.(string) + deviceID, okDeviceID = rawDeviceID.(string) + + if !okConferenceId || !okCallId || !okDeviceID { + logrus.Warn("Ignoring invalid message without IDs") + return + } } logger := logrus.WithFields(logrus.Fields{ - "type": evt.Type.Type, - "user_id": evt.Sender.String(), - "conf_id": conferenceID, + "type": evt.Type.Type, + "user_id": userID, + "conf_id": conferenceID, + "call_id": callID, + "device_id": deviceID, }) conference := r.conferenceSinks[conferenceID] @@ -101,7 +117,7 @@ func (r *Router) handleMatrixEvent(evt *event.Event) { r.config, r.matrix.CreateForConference(conferenceID), createConferenceEndNotifier(conferenceID, r.channel), - evt.Sender, + userID, evt.Content.AsCallInvite(), ) if err != nil { @@ -122,9 +138,10 @@ func (r *Router) handleMatrixEvent(evt *event.Event) { // A helper function to deal with messages that can't be sent due to the conference closed. // Not a function due to the need to capture local environment. sendToConference := func(eventContent conf.MessageContent) { + sender := conf.ParticipantID{userID, id.DeviceID(deviceID), callID} // At this point the conference is not nil. // Let's check if the channel is still available. - if conference.Send(conf.MatrixMessage{UserID: evt.Sender, Content: eventContent}) != nil { + if conference.Send(conf.MatrixMessage{Content: eventContent, RawEvent: evt, Sender: sender}) != nil { // If sending failed, then the conference is over. delete(r.conferenceSinks, conferenceID) // Since we were not able to send the message, let's re-process it now. diff --git a/pkg/signaling/matrix.go b/pkg/signaling/matrix.go index ff0fe49..29d2681 100644 --- a/pkg/signaling/matrix.go +++ b/pkg/signaling/matrix.go @@ -46,6 +46,7 @@ type MatrixRecipient struct { UserID id.UserID DeviceID id.DeviceID RemoteSessionID id.SessionID + CallID string } // Interface that abstracts sending Send-to-device messages for the conference. @@ -63,7 +64,7 @@ func (m *MatrixForConference) SendSDPAnswer( ) { eventContent := &event.Content{ Parsed: event.CallAnswerEventContent{ - BaseCallEventContent: m.createBaseEventContent(recipient.RemoteSessionID), + BaseCallEventContent: m.createBaseEventContent(recipient.CallID, recipient.RemoteSessionID), Answer: event.CallData{ Type: "answer", SDP: sdp, @@ -78,7 +79,7 @@ func (m *MatrixForConference) SendSDPAnswer( func (m *MatrixForConference) SendICECandidates(recipient MatrixRecipient, candidates []event.CallCandidate) { eventContent := &event.Content{ Parsed: event.CallCandidatesEventContent{ - BaseCallEventContent: m.createBaseEventContent(recipient.RemoteSessionID), + BaseCallEventContent: m.createBaseEventContent(recipient.CallID, recipient.RemoteSessionID), Candidates: candidates, }, } @@ -89,7 +90,7 @@ func (m *MatrixForConference) SendICECandidates(recipient MatrixRecipient, candi func (m *MatrixForConference) SendCandidatesGatheringFinished(recipient MatrixRecipient) { eventContent := &event.Content{ Parsed: event.CallCandidatesEventContent{ - BaseCallEventContent: m.createBaseEventContent(recipient.RemoteSessionID), + BaseCallEventContent: m.createBaseEventContent(recipient.CallID, recipient.RemoteSessionID), Candidates: []event.CallCandidate{{Candidate: ""}}, }, } @@ -100,7 +101,7 @@ func (m *MatrixForConference) SendCandidatesGatheringFinished(recipient MatrixRe func (m *MatrixForConference) SendHangup(recipient MatrixRecipient, reason event.CallHangupReason) { eventContent := &event.Content{ Parsed: event.CallHangupEventContent{ - BaseCallEventContent: m.createBaseEventContent(recipient.RemoteSessionID), + BaseCallEventContent: m.createBaseEventContent(recipient.CallID, recipient.RemoteSessionID), Reason: reason, }, } @@ -108,9 +109,12 @@ func (m *MatrixForConference) SendHangup(recipient MatrixRecipient, reason event m.sendToDevice(recipient, event.CallHangup, eventContent) } -func (m *MatrixForConference) createBaseEventContent(destSessionID id.SessionID) event.BaseCallEventContent { +func (m *MatrixForConference) createBaseEventContent( + callID string, + destSessionID id.SessionID, +) event.BaseCallEventContent { return event.BaseCallEventContent{ - CallID: m.conferenceID, + CallID: callID, ConfID: m.conferenceID, DeviceID: m.client.DeviceID, SenderSessionID: LocalSessionID, From 01fef5ccef69863911375cfd4dbff6960c170373 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Mon, 28 Nov 2022 22:15:20 +0100 Subject: [PATCH 32/62] conference: use SFU device ID on SelectAnswer --- pkg/conference/matrix.go | 4 ++-- pkg/signaling/matrix.go | 5 +++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/pkg/conference/matrix.go b/pkg/conference/matrix.go index ed6e853..02671b2 100644 --- a/pkg/conference/matrix.go +++ b/pkg/conference/matrix.go @@ -106,12 +106,12 @@ func (c *Conference) onSelectAnswer(participantID ParticipantID, ev *event.CallS if participant := c.getParticipant(participantID, nil); participant != nil { participant.logger.Info("Received remote answer selection") - if ev.SelectedPartyID != participantID.DeviceID.String() { + if ev.SelectedPartyID != string(c.signaling.DeviceID()) { c.logger.WithFields(logrus.Fields{ "device_id": ev.SelectedPartyID, "user_id": participantID, }).Errorf("Call was answered on a different device, kicking this peer") - participant.peer.Terminate() + c.removeParticipant(participantID) } } } diff --git a/pkg/signaling/matrix.go b/pkg/signaling/matrix.go index 29d2681..40b81cd 100644 --- a/pkg/signaling/matrix.go +++ b/pkg/signaling/matrix.go @@ -55,6 +55,7 @@ type MatrixSignaling interface { SendICECandidates(recipient MatrixRecipient, candidates []event.CallCandidate) SendCandidatesGatheringFinished(recipient MatrixRecipient) SendHangup(recipient MatrixRecipient, reason event.CallHangupReason) + DeviceID() id.DeviceID } func (m *MatrixForConference) SendSDPAnswer( @@ -109,6 +110,10 @@ func (m *MatrixForConference) SendHangup(recipient MatrixRecipient, reason event m.sendToDevice(recipient, event.CallHangup, eventContent) } +func (m *MatrixForConference) DeviceID() id.DeviceID { + return m.client.DeviceID +} + func (m *MatrixForConference) createBaseEventContent( callID string, destSessionID id.SessionID, From 479863d706c66608e5e5b59cf2b6540517c9f016 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Mon, 28 Nov 2022 22:35:32 +0100 Subject: [PATCH 33/62] config: make log level configurable --- pkg/config/config.go | 2 ++ pkg/main.go | 20 +++++++++++++++++--- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/pkg/config/config.go b/pkg/config/config.go index d908c34..f9fde9e 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -17,6 +17,8 @@ type Config struct { Matrix signaling.Config `yaml:"matrix"` // Conference (call) configuration. Conference conference.Config `yaml:"conference"` + // Starting from which level to log stuff. + LogLevel string `yaml:"log"` } // Tries to load a config from the `CONFIG` environment variable. diff --git a/pkg/main.go b/pkg/main.go index 9d92ed3..05a8f51 100644 --- a/pkg/main.go +++ b/pkg/main.go @@ -40,9 +40,6 @@ func main() { // Initialize logging subsystem (formatting, global logging framework etc). logrus.SetFormatter(&logrus.TextFormatter{FullTimestamp: true, ForceColors: true}) - // Temporarily enable debug logging. - logrus.SetLevel(logrus.DebugLevel) - // Define functions that are called before exiting. // This is useful to stop the profiler if it's enabled. deferred_functions := []func(){} @@ -71,6 +68,23 @@ func main() { return } + switch config.LogLevel { + case "debug": + logrus.SetLevel(logrus.DebugLevel) + case "info": + logrus.SetLevel(logrus.InfoLevel) + case "warn": + logrus.SetLevel(logrus.WarnLevel) + case "error": + logrus.SetLevel(logrus.ErrorLevel) + case "fatal": + logrus.SetLevel(logrus.FatalLevel) + case "panic": + logrus.SetLevel(logrus.PanicLevel) + default: + logrus.SetLevel(logrus.InfoLevel) + } + // Create matrix client. matrixClient := signaling.NewMatrixClient(config.Matrix) From 7fbacad92e58d6ef1dcba7c9c3aaa28afead2d88 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Mon, 28 Nov 2022 23:27:06 +0100 Subject: [PATCH 34/62] channel: improve documentation and API surface --- pkg/common/channel.go | 41 +++++++++++++++++++++++++++++++++---- pkg/conference/processor.go | 15 ++------------ pkg/conference/start.go | 2 +- pkg/router.go | 2 +- 4 files changed, 41 insertions(+), 19 deletions(-) diff --git a/pkg/common/channel.go b/pkg/common/channel.go index 73dc32f..5c702f1 100644 --- a/pkg/common/channel.go +++ b/pkg/common/channel.go @@ -2,22 +2,34 @@ package common import "sync/atomic" +// In Go, unbounded channel means something different than what it means in Rust. +// I.e. unlike Rust, "unbounded" in Go means that the channel has **no buffer**, +// meaning that each attempt to send will block the channel until the receiver +// reads it. Majority of primitives here in `waterfall` are designed under assumption +// that sending is not blocking. +const UnboundedChannelSize = 128 + // Creates a new channel, returns two counterparts of it where one can only send and another can only receive. // Unlike traditional Go channels, these allow the receiver to mark the channel as closed which would then fail // to send any messages to the channel over `Send“. func NewChannel[M any]() (Sender[M], Receiver[M]) { - channel := make(chan M, 128) + channel := make(chan M, UnboundedChannelSize) closed := &atomic.Bool{} sender := Sender[M]{channel, closed} receiver := Receiver[M]{channel, closed} return sender, receiver } +// Sender counterpart of the channel. type Sender[M any] struct { - channel chan<- M + // The channel itself. + channel chan<- M + // Atomic variable that indicates whether the channel is closed. receiverClosed *atomic.Bool } +// Tries to send a message if the channel is not closed. +// Returns the message back if the channel is closed. func (s *Sender[M]) Send(message M) *M { if !s.receiverClosed.Load() { s.channel <- message @@ -27,11 +39,32 @@ func (s *Sender[M]) Send(message M) *M { } } +// The receiver counterpart of the channel. type Receiver[M any] struct { - Channel <-chan M + // The channel itself. It's public, so that we can combine it in `select` statements. + Channel <-chan M + // Atomic variable that indicates whether the channel is closed. receiverClosed *atomic.Bool } -func (r *Receiver[M]) Close() { +// Marks the channel as closed, which means that no messages could be sent via this channel. +// Any attempt to send a message would result in an error. This is similar to closing the +// channel except that we don't close the underlying channel (since in Go receivers can't +// close the channel). +// +// This function reads (in a non-blocking way) all pending messages until blocking. Otherwise, +// they will stay forver in a channel and get lost. +func (r *Receiver[M]) Close() []M { r.receiverClosed.Store(true) + + messages := make([]M, 0) + for { + msg, ok := <-r.Channel + if !ok { + break + } + messages = append(messages, msg) + } + + return messages } diff --git a/pkg/conference/processor.go b/pkg/conference/processor.go index fbaf0b9..ca23244 100644 --- a/pkg/conference/processor.go +++ b/pkg/conference/processor.go @@ -26,22 +26,11 @@ func (c *Conference) processMessages() { if len(c.participants) == 0 { c.logger.Info("No more participants, stopping the conference") // Close the channel so that the sender can't push any messages. - c.matrixMessages.Close() - - // Let's read remaining messages from the channel (otherwise the caller will be - // blocked in case of unbuffered channels). We must read **all** pending messages. - messages := make([]MatrixMessage, 0) - for { - msg, ok := <-c.matrixMessages.Channel - if !ok { - break - } - messages = append(messages, msg) - } + unreadMessages := c.matrixMessages.Close() // Send the information that we ended to the owner and pass the message // that we did not process (so that we don't drop it silently). - c.endNotifier.Notify(messages) + c.endNotifier.Notify(unreadMessages) return } } diff --git a/pkg/conference/start.go b/pkg/conference/start.go index 02d12b2..85e38a8 100644 --- a/pkg/conference/start.go +++ b/pkg/conference/start.go @@ -43,7 +43,7 @@ func StartConference( matrixMessages: receiver, endNotifier: conferenceEndNotifier, participants: make(map[ParticipantID]*Participant), - peerMessages: make(chan common.Message[ParticipantID, peer.MessageContent], 128), + peerMessages: make(chan common.Message[ParticipantID, peer.MessageContent], common.UnboundedChannelSize), logger: logrus.WithFields(logrus.Fields{"conf_id": confID}), } diff --git a/pkg/router.go b/pkg/router.go index b2a9630..af846db 100644 --- a/pkg/router.go +++ b/pkg/router.go @@ -45,7 +45,7 @@ func newRouter(matrix *signaling.MatrixClient, config conf.Config) chan<- Router matrix: matrix, conferenceSinks: make(map[string]*common.Sender[conf.MatrixMessage]), config: config, - channel: make(chan RouterMessage, 128), + channel: make(chan RouterMessage, common.UnboundedChannelSize), } // Start the main loop of the Router. From a4f420e5f869c3965d4cad205e47a43f63497e46 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Tue, 29 Nov 2022 20:22:39 +0100 Subject: [PATCH 35/62] conference: check for nil rtp tracks --- pkg/conference/processor.go | 4 ++++ pkg/conference/state.go | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/pkg/conference/processor.go b/pkg/conference/processor.go index ca23244..c631d37 100644 --- a/pkg/conference/processor.go +++ b/pkg/conference/processor.go @@ -140,6 +140,10 @@ func (c *Conference) handleDataChannelMessage(participant *Participant, sfuMessa // Get the tracks that correspond to the tracks that the participant wants to receive. for _, track := range c.getTracks(sfuMessage.Start) { + if track == nil { + participant.logger.Errorf("Bug, track is nil") + } + if err := participant.peer.SubscribeTo(track); err != nil { participant.logger.Errorf("Failed to subscribe to track: %v", err) return diff --git a/pkg/conference/state.go b/pkg/conference/state.go index 342f68e..1cb77d2 100644 --- a/pkg/conference/state.go +++ b/pkg/conference/state.go @@ -79,7 +79,7 @@ func (c *Conference) getAvailableStreamsFor(forParticipant ParticipantID) event. // Helper that returns the list of streams inside this conference that match the given stream IDs. func (c *Conference) getTracks(identifiers []event.SFUTrackDescription) []*webrtc.TrackLocalStaticRTP { - tracks := make([]*webrtc.TrackLocalStaticRTP, len(identifiers)) + tracks := make([]*webrtc.TrackLocalStaticRTP, 0) for _, participant := range c.participants { // Check if this participant has any of the tracks that we're looking for. for _, identifier := range identifiers { From 67e469ebfc57471be07301e2e5728a88327059e6 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Tue, 29 Nov 2022 21:07:18 +0100 Subject: [PATCH 36/62] conference: fix the subscribe/unsubscribe logic --- pkg/conference/matrix.go | 1 - pkg/conference/processor.go | 39 +++++++++++++++++++++++++------------ pkg/conference/state.go | 1 + pkg/peer/peer.go | 5 ++++- pkg/peer/webrtc.go | 21 +++++++++++--------- pkg/signaling/matrix.go | 12 ------------ 6 files changed, 44 insertions(+), 35 deletions(-) diff --git a/pkg/conference/matrix.go b/pkg/conference/matrix.go index 02671b2..535d0ab 100644 --- a/pkg/conference/matrix.go +++ b/pkg/conference/matrix.go @@ -21,7 +21,6 @@ func (c *Conference) onNewParticipant(participantID ParticipantID, inviteEvent * logger := c.logger.WithFields(logrus.Fields{ "user_id": participantID.UserID, "device_id": participantID.DeviceID, - "call_id": participantID.CallID, }) logger.Info("Incoming call invite") diff --git a/pkg/conference/processor.go b/pkg/conference/processor.go index c631d37..9cdc300 100644 --- a/pkg/conference/processor.go +++ b/pkg/conference/processor.go @@ -7,6 +7,7 @@ import ( "github.com/matrix-org/waterfall/pkg/common" "github.com/matrix-org/waterfall/pkg/peer" "github.com/pion/webrtc/v3" + "golang.org/x/exp/slices" "maunium.net/go/mautrix/event" ) @@ -85,7 +86,7 @@ func (c *Conference) processPeerMessage(message common.Message[ParticipantID, pe } case peer.NewICECandidate: - participant.logger.Info("Received a new local ICE candidate") + participant.logger.Debug("Received a new local ICE candidate") // Convert WebRTC ICE candidate to Matrix ICE candidate. jsonCandidate := msg.Candidate.ToJSON() @@ -111,7 +112,7 @@ func (c *Conference) processPeerMessage(message common.Message[ParticipantID, pe }) case peer.DataChannelMessage: - participant.logger.Info("Sent data channel message") + participant.logger.Debug("Received data channel message") var sfuMessage event.SFUMessage if err := json.Unmarshal([]byte(msg.Message), &sfuMessage); err != nil { c.logger.Errorf("Failed to unmarshal SFU message: %v", err) @@ -136,14 +137,28 @@ func (c *Conference) processPeerMessage(message common.Message[ParticipantID, pe func (c *Conference) handleDataChannelMessage(participant *Participant, sfuMessage event.SFUMessage) { switch sfuMessage.Op { case event.SFUOperationSelect: - participant.logger.Info("Sent select request over DC") - - // Get the tracks that correspond to the tracks that the participant wants to receive. - for _, track := range c.getTracks(sfuMessage.Start) { - if track == nil { - participant.logger.Errorf("Bug, track is nil") + participant.logger.Info("Received select request over DC") + + // Find tracks based on what we were asked for. + tracks := c.getTracks(sfuMessage.Start) + + // Let's check if we have all the tracks that we were asked for are there. + // If not, we will list which are not available (later on we must inform participant + // about it unless the participant retries it). + if len(tracks) != len(sfuMessage.Start) { + for _, expected := range sfuMessage.Start { + found := slices.IndexFunc(tracks, func(track *webrtc.TrackLocalStaticRTP) bool { + return track.StreamID() == expected.StreamID && track.ID() == expected.TrackID + }) + + if found == -1 { + c.logger.Warnf("Track not found: %s", expected.TrackID) + } } + } + // Subscribe to the found tracks. + for _, track := range tracks { if err := participant.peer.SubscribeTo(track); err != nil { participant.logger.Errorf("Failed to subscribe to track: %v", err) return @@ -151,7 +166,7 @@ func (c *Conference) handleDataChannelMessage(participant *Participant, sfuMessa } case event.SFUOperationAnswer: - participant.logger.Info("Sent SDP answer over DC") + participant.logger.Info("Received SDP answer over DC") if err := participant.peer.ProcessSDPAnswer(sfuMessage.SDP); err != nil { participant.logger.Errorf("Failed to set SDP answer: %v", err) @@ -159,7 +174,7 @@ func (c *Conference) handleDataChannelMessage(participant *Participant, sfuMessa } case event.SFUOperationPublish: - participant.logger.Info("Sent SDP offer over DC") + participant.logger.Info("Received SDP offer over DC") answer, err := participant.peer.ProcessSDPOffer(sfuMessage.SDP) if err != nil { @@ -173,13 +188,13 @@ func (c *Conference) handleDataChannelMessage(participant *Participant, sfuMessa }) case event.SFUOperationUnpublish: - participant.logger.Info("Sent unpublish over DC") + participant.logger.Info("Received unpublish over DC") // TODO: Clarify the semantics of unpublish. case event.SFUOperationAlive: // FIXME: Handle the heartbeat message here (updating the last timestamp etc). case event.SFUOperationMetadata: - participant.logger.Info("Sent metadata over DC") + participant.logger.Info("Received metadata over DC") participant.streamMetadata = sfuMessage.Metadata c.resendMetadataToAllExcept(participant.id) diff --git a/pkg/conference/state.go b/pkg/conference/state.go index 1cb77d2..0ae7480 100644 --- a/pkg/conference/state.go +++ b/pkg/conference/state.go @@ -88,6 +88,7 @@ func (c *Conference) getTracks(identifiers []event.SFUTrackDescription) []*webrt } } } + return tracks } diff --git a/pkg/peer/peer.go b/pkg/peer/peer.go index bfc2e05..b5d6e9e 100644 --- a/pkg/peer/peer.go +++ b/pkg/peer/peer.go @@ -106,6 +106,10 @@ func (p *Peer[ID]) SubscribeTo(track *webrtc.TrackLocalStaticRTP) error { func (p *Peer[ID]) UnsubscribeFrom(tracks []*webrtc.TrackLocalStaticRTP) { // That's unfortunately an O(m*n) operation, but we don't expect the number of tracks to be big. for _, presentTrack := range p.peerConnection.GetSenders() { + if presentTrack.Track() == nil { + continue + } + for _, trackToUnsubscribe := range tracks { presentTrackID, presentStreamID := presentTrack.Track().ID(), presentTrack.Track().StreamID() trackID, streamID := trackToUnsubscribe.ID(), trackToUnsubscribe.StreamID() @@ -165,7 +169,6 @@ func (p *Peer[ID]) ProcessSDPAnswer(sdpAnswer string) error { // Applies the sdp offer received from the remote peer and generates an SDP answer. func (p *Peer[ID]) ProcessSDPOffer(sdpOffer string) (*webrtc.SessionDescription, error) { - p.logger.WithField("sdpOffer", sdpOffer).Debug("processing SDP offer") err := p.peerConnection.SetRemoteDescription(webrtc.SessionDescription{ Type: webrtc.SDPTypeOffer, SDP: sdpOffer, diff --git a/pkg/peer/webrtc.go b/pkg/peer/webrtc.go index 82238c7..52a86da 100644 --- a/pkg/peer/webrtc.go +++ b/pkg/peer/webrtc.go @@ -23,6 +23,7 @@ func (p *Peer[ID]) onRtpTrackReceived(remoteTrack *webrtc.TrackRemote, receiver rtcp := []rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: uint32(remoteTrack.SSRC())}} if rtcpSendErr := p.peerConnection.WriteRTCP(rtcp); rtcpSendErr != nil { p.logger.Errorf("Failed to send RTCP PLI: %v", rtcpSendErr) + return } } }() @@ -56,12 +57,14 @@ func (p *Peer[ID]) onRtpTrackReceived(remoteTrack *webrtc.TrackRemote, receiver p.logger.WithError(readErr).Error("failed to read from remote track") } p.sink.Send(PublishedTrackFailed{Track: localTrack}) + return } // ErrClosedPipe means we don't have any subscribers, this is ok if no peers have connected yet. if _, err = localTrack.Write(rtpBuf[:index]); err != nil && !errors.Is(err, io.ErrClosedPipe) { p.logger.WithError(err).Error("failed to write to local track") p.sink.Send(PublishedTrackFailed{Track: localTrack}) + return } } }() @@ -98,7 +101,7 @@ func (p *Peer[ID]) onNegotiationNeeded() { // A callback that is called once we receive an ICE connection state change for this peer connection. func (p *Peer[ID]) onICEConnectionStateChanged(state webrtc.ICEConnectionState) { - p.logger.WithField("state", state).Debug("ICE connection state changed") + p.logger.WithField("state", state).Info("ICE connection state changed") // TODO: Ask Simon if we should do it here as in the previous implementation. switch state { @@ -120,7 +123,7 @@ func (p *Peer[ID]) onSignalingStateChanged(state webrtc.SignalingState) { } func (p *Peer[ID]) onConnectionStateChanged(state webrtc.PeerConnectionState) { - p.logger.WithField("state", state).Debug("connection state changed") + p.logger.WithField("state", state).Info("Connection state changed") switch state { case webrtc.PeerConnectionStateFailed, webrtc.PeerConnectionStateDisconnected, webrtc.PeerConnectionStateClosed: @@ -136,33 +139,33 @@ func (p *Peer[ID]) onDataChannelReady(dc *webrtc.DataChannel) { defer p.dataChannelMutex.Unlock() if p.dataChannel != nil { - p.logger.Error("data channel already exists") + p.logger.Error("Data channel already exists") p.dataChannel.Close() return } p.dataChannel = dc - p.logger.WithField("label", dc.Label()).Info("data channel ready") + p.logger.WithField("label", dc.Label()).Info("Data channel ready") dc.OnOpen(func() { - p.logger.Info("data channel opened") + p.logger.Info("Data channel opened") p.sink.Send(DataChannelAvailable{}) }) dc.OnMessage(func(msg webrtc.DataChannelMessage) { - p.logger.WithField("message", msg).Debug("data channel message received") + p.logger.WithField("message", msg).Debug("Data channel message received") if msg.IsString { p.sink.Send(DataChannelMessage{Message: string(msg.Data)}) } else { - p.logger.Warn("data channel message is not a string, ignoring") + p.logger.Warn("Data channel message is not a string, ignoring") } }) dc.OnError(func(err error) { - p.logger.WithError(err).Error("data channel error") + p.logger.WithError(err).Error("Data channel error") }) dc.OnClose(func() { - p.logger.Info("data channel closed") + p.logger.Info("Data channel closed") }) } diff --git a/pkg/signaling/matrix.go b/pkg/signaling/matrix.go index 40b81cd..31aef72 100644 --- a/pkg/signaling/matrix.go +++ b/pkg/signaling/matrix.go @@ -17,8 +17,6 @@ limitations under the License. package signaling import ( - "encoding/json" - "github.com/sirupsen/logrus" "maunium.net/go/mautrix" "maunium.net/go/mautrix/event" @@ -144,16 +142,6 @@ func (m *MatrixForConference) sendToDevice(user MatrixRecipient, eventType event }, } - { - // TODO: Remove this once - serialized, err := json.Marshal(sendRequest) - if err != nil { - logger.WithError(err).Error("Failed to serialize to-device message") - return - } - logger.Debugf("Sending to-device message: %s", string(serialized)) - } - if _, err := m.client.SendToDevice(eventType, sendRequest); err != nil { logger.Errorf("failed to send to-device event: %w", err) } From ff687e2f229cbcc0dc24e4b5d8f29ee15df0745b Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Tue, 29 Nov 2022 21:25:34 +0100 Subject: [PATCH 37/62] peer: fix a potential segfault --- pkg/peer/peer.go | 11 ++++++----- pkg/router.go | 1 - 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pkg/peer/peer.go b/pkg/peer/peer.go index b5d6e9e..cb1c737 100644 --- a/pkg/peer/peer.go +++ b/pkg/peer/peer.go @@ -105,16 +105,17 @@ func (p *Peer[ID]) SubscribeTo(track *webrtc.TrackLocalStaticRTP) error { // Unsubscribes from the given list of tracks. func (p *Peer[ID]) UnsubscribeFrom(tracks []*webrtc.TrackLocalStaticRTP) { // That's unfortunately an O(m*n) operation, but we don't expect the number of tracks to be big. - for _, presentTrack := range p.peerConnection.GetSenders() { - if presentTrack.Track() == nil { - continue + for _, sender := range p.peerConnection.GetSenders() { + currentTrack := sender.Track() + if currentTrack == nil { + return } for _, trackToUnsubscribe := range tracks { - presentTrackID, presentStreamID := presentTrack.Track().ID(), presentTrack.Track().StreamID() + presentTrackID, presentStreamID := currentTrack.ID(), currentTrack.StreamID() trackID, streamID := trackToUnsubscribe.ID(), trackToUnsubscribe.StreamID() if presentTrackID == trackID && presentStreamID == streamID { - if err := p.peerConnection.RemoveTrack(presentTrack); err != nil { + if err := p.peerConnection.RemoveTrack(sender); err != nil { p.logger.WithError(err).Error("failed to remove track") } } diff --git a/pkg/router.go b/pkg/router.go index af846db..bfc5b61 100644 --- a/pkg/router.go +++ b/pkg/router.go @@ -102,7 +102,6 @@ func (r *Router) handleMatrixEvent(evt *event.Event) { "type": evt.Type.Type, "user_id": userID, "conf_id": conferenceID, - "call_id": callID, "device_id": deviceID, }) From d5ffe81999070b7be1c7b3fea0a3c1e480589007 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Wed, 30 Nov 2022 14:13:40 +0100 Subject: [PATCH 38/62] peer: don't fail on certain RTCP write errors --- pkg/peer/webrtc.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/peer/webrtc.go b/pkg/peer/webrtc.go index 52a86da..fea064a 100644 --- a/pkg/peer/webrtc.go +++ b/pkg/peer/webrtc.go @@ -21,8 +21,8 @@ func (p *Peer[ID]) onRtpTrackReceived(remoteTrack *webrtc.TrackRemote, receiver ticker := time.NewTicker(time.Millisecond * 500) // every 500ms for range ticker.C { rtcp := []rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: uint32(remoteTrack.SSRC())}} - if rtcpSendErr := p.peerConnection.WriteRTCP(rtcp); rtcpSendErr != nil { - p.logger.Errorf("Failed to send RTCP PLI: %v", rtcpSendErr) + if err := p.peerConnection.WriteRTCP(rtcp); err != nil && !errors.Is(err, io.ErrClosedPipe) { + p.logger.Errorf("Failed to send RTCP PLI: %v", err) return } } From 5563dd544e99d812fa631c74a3b2ab78a3949ec8 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Wed, 30 Nov 2022 14:16:51 +0100 Subject: [PATCH 39/62] peer: use sdpAnswer from crateAnswer Instead of calling `LocalDescription()` to get it again. It does not seem to make any difference in our particular case? --- pkg/peer/peer.go | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/pkg/peer/peer.go b/pkg/peer/peer.go index cb1c737..80351b8 100644 --- a/pkg/peer/peer.go +++ b/pkg/peer/peer.go @@ -190,11 +190,5 @@ func (p *Peer[ID]) ProcessSDPOffer(sdpOffer string) (*webrtc.SessionDescription, return nil, ErrCantSetLocalDescription } - sdpAnswer := p.peerConnection.LocalDescription() - if sdpAnswer == nil { - p.logger.WithError(err).Error("could not generate a local description") - return nil, ErrCantCreateLocalDescription - } - - return sdpAnswer, nil + return &answer, nil } From ab53c8e46be64ee5e48b35caf780c689bed9bf69 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Wed, 30 Nov 2022 16:03:12 +0100 Subject: [PATCH 40/62] conference: handle metadata of streams properly Don't include tracks in the metadata that are not yet published (for which we don't have any remote streams available). Also, inform about metadata changes once tracks get published and unpublished. --- pkg/conference/processor.go | 4 +++- pkg/conference/state.go | 17 ++++++++++++++--- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/pkg/conference/processor.go b/pkg/conference/processor.go index 9cdc300..6c0289f 100644 --- a/pkg/conference/processor.go +++ b/pkg/conference/processor.go @@ -49,7 +49,6 @@ func (c *Conference) processPeerMessage(message common.Message[ParticipantID, pe switch msg := message.Content.(type) { case peer.JoinedTheCall: participant.logger.Info("Joined the call") - c.resendMetadataToAllExcept(participant.id) case peer.LeftTheCall: participant.logger.Info("Left the call") @@ -69,6 +68,7 @@ func (c *Conference) processPeerMessage(message common.Message[ParticipantID, pe } participant.publishedTracks[key] = msg.Track + c.resendMetadataToAllExcept(participant.id) case peer.PublishedTrackFailed: participant.logger.Infof("Failed published track: %s", msg.Track.ID()) @@ -85,6 +85,8 @@ func (c *Conference) processPeerMessage(message common.Message[ParticipantID, pe otherParticipant.peer.UnsubscribeFrom([]*webrtc.TrackLocalStaticRTP{msg.Track}) } + c.resendMetadataToAllExcept(participant.id) + case peer.NewICECandidate: participant.logger.Debug("Received a new local ICE candidate") diff --git a/pkg/conference/state.go b/pkg/conference/state.go index 0ae7480..36368ec 100644 --- a/pkg/conference/state.go +++ b/pkg/conference/state.go @@ -63,13 +63,24 @@ func (c *Conference) removeParticipant(participantID ParticipantID) { } // Helper to get the list of available streams for a given participant, i.e. the list of streams -// that a given participant **can subscribe to**. +// that a given participant **can subscribe to**. Each stream may have multiple tracks. func (c *Conference) getAvailableStreamsFor(forParticipant ParticipantID) event.CallSDPStreamMetadata { streamsMetadata := make(event.CallSDPStreamMetadata) for id, participant := range c.participants { + // Skip us. As we know about our own tracks. if forParticipant != id { - for streamID, metadata := range participant.streamMetadata { - streamsMetadata[streamID] = metadata + // Now, find out which of published tracks belong to the streams for which we have metadata + // available and construct a metadata map for a given participant based on that. + for _, track := range participant.publishedTracks { + trackID, streamID := track.ID(), track.StreamID() + + if metadata, ok := streamsMetadata[track.StreamID()]; ok { + metadata.Tracks[trackID] = event.CallSDPStreamMetadataTrack{} + streamsMetadata[streamID] = metadata + } else if metadata, ok := participant.streamMetadata[streamID]; ok { + metadata.Tracks = event.CallSDPStreamMetadataTracks{trackID: event.CallSDPStreamMetadataTrack{}} + streamsMetadata[streamID] = metadata + } } } } From 0c8d52879811e77727686c368b73f717e6054668 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Wed, 30 Nov 2022 16:24:53 +0100 Subject: [PATCH 41/62] conference: add hacky way to send unknown metadata This is a temporary measure to fix the screen sharing. --- pkg/conference/state.go | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/pkg/conference/state.go b/pkg/conference/state.go index 36368ec..129cc9b 100644 --- a/pkg/conference/state.go +++ b/pkg/conference/state.go @@ -74,12 +74,22 @@ func (c *Conference) getAvailableStreamsFor(forParticipant ParticipantID) event. for _, track := range participant.publishedTracks { trackID, streamID := track.ID(), track.StreamID() - if metadata, ok := streamsMetadata[track.StreamID()]; ok { + if metadata, ok := streamsMetadata[streamID]; ok { metadata.Tracks[trackID] = event.CallSDPStreamMetadataTrack{} streamsMetadata[streamID] = metadata } else if metadata, ok := participant.streamMetadata[streamID]; ok { metadata.Tracks = event.CallSDPStreamMetadataTracks{trackID: event.CallSDPStreamMetadataTrack{}} streamsMetadata[streamID] = metadata + } else { + participant.logger.Warnf("Don't have metadata for stream %s", streamID) + // FIXME: A hacky way to nevertheless send a metadata about the stream and track for which we have + // no metadata. This is against the MSC actually since we know nothing about the stream in + // this case. But it was implemented like this in the `main` branch of the SFU. + streamsMetadata[streamID] = event.CallSDPStreamMetadataObject{ + UserID: participant.id.UserID, + DeviceID: participant.id.DeviceID, + Tracks: event.CallSDPStreamMetadataTracks{trackID: event.CallSDPStreamMetadataTrack{}}, + } } } } From bd313c0c6811f9a80a855ac0de9ba035bcf48ac6 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Wed, 30 Nov 2022 16:54:23 +0100 Subject: [PATCH 42/62] conference: update TODOs and FIXMEs --- pkg/conference/participant.go | 2 +- pkg/peer/webrtc.go | 4 ++-- pkg/router.go | 1 - 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/pkg/conference/participant.go b/pkg/conference/participant.go index 61b9df2..747a267 100644 --- a/pkg/conference/participant.go +++ b/pkg/conference/participant.go @@ -45,7 +45,7 @@ func (p *Participant) sendDataChannelMessage(toSend event.SFUMessage) { } if err := p.peer.SendOverDataChannel(string(jsonToSend)); err != nil { - // FIXME: We must buffer the message in this case and re-send it once the data channel is recovered! + // TODO: We must buffer the message in this case and re-send it once the data channel is recovered! p.logger.Error("Failed to send data channel message") } } diff --git a/pkg/peer/webrtc.go b/pkg/peer/webrtc.go index fea064a..1a62358 100644 --- a/pkg/peer/webrtc.go +++ b/pkg/peer/webrtc.go @@ -103,10 +103,10 @@ func (p *Peer[ID]) onNegotiationNeeded() { func (p *Peer[ID]) onICEConnectionStateChanged(state webrtc.ICEConnectionState) { p.logger.WithField("state", state).Info("ICE connection state changed") - // TODO: Ask Simon if we should do it here as in the previous implementation. switch state { case webrtc.ICEConnectionStateFailed, webrtc.ICEConnectionStateDisconnected: - // TODO: We may want to treat it as an opportunity for the ICE restart instead. + // TODO: Ask Simon if we should do it here as in the previous implementation. + // Ideally we want to perform an ICE restart here. // p.notify <- PeerLeftTheCall{sender: p.data} case webrtc.ICEConnectionStateCompleted, webrtc.ICEConnectionStateConnected: // FIXME: Start keep-alive timer over the data channel to check the connecitons that hanged. diff --git a/pkg/router.go b/pkg/router.go index bfc5b61..c502b18 100644 --- a/pkg/router.go +++ b/pkg/router.go @@ -145,7 +145,6 @@ func (r *Router) handleMatrixEvent(evt *event.Event) { delete(r.conferenceSinks, conferenceID) // Since we were not able to send the message, let's re-process it now. // Note, we probably do not want to block here! - // TODO: Do it better (use buffered channels or something). r.handleMatrixEvent(evt) } } From 1b7bdbedf35d5eec50ad7c9cbd5fbcc1ff1a6c73 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Wed, 30 Nov 2022 17:13:08 +0100 Subject: [PATCH 43/62] conference: attach stream metadata to offers --- pkg/conference/processor.go | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/pkg/conference/processor.go b/pkg/conference/processor.go index 6c0289f..b585f22 100644 --- a/pkg/conference/processor.go +++ b/pkg/conference/processor.go @@ -185,8 +185,9 @@ func (c *Conference) handleDataChannelMessage(participant *Participant, sfuMessa } participant.sendDataChannelMessage(event.SFUMessage{ - Op: event.SFUOperationAnswer, - SDP: answer.SDP, + Op: event.SFUOperationAnswer, + SDP: answer.SDP, + Metadata: c.getAvailableStreamsFor(participant.id), }) case event.SFUOperationUnpublish: @@ -196,7 +197,11 @@ func (c *Conference) handleDataChannelMessage(participant *Participant, sfuMessa case event.SFUOperationAlive: // FIXME: Handle the heartbeat message here (updating the last timestamp etc). case event.SFUOperationMetadata: - participant.logger.Info("Received metadata over DC") + streamIDs := make([]string, 0, len(sfuMessage.Metadata)) + for streamID := range sfuMessage.Metadata { + streamIDs = append(streamIDs, streamID) + } + participant.logger.Infof("Received metadata over DC: %v", streamIDs) participant.streamMetadata = sfuMessage.Metadata c.resendMetadataToAllExcept(participant.id) From 39deed0da3a4ec2fadccb4624d79ba5a2634555d Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Wed, 30 Nov 2022 17:29:37 +0100 Subject: [PATCH 44/62] conference: accept the metadata from the DC offer --- pkg/conference/processor.go | 8 ++------ pkg/conference/state.go | 8 -------- 2 files changed, 2 insertions(+), 14 deletions(-) diff --git a/pkg/conference/processor.go b/pkg/conference/processor.go index b585f22..b4d0e5b 100644 --- a/pkg/conference/processor.go +++ b/pkg/conference/processor.go @@ -184,6 +184,8 @@ func (c *Conference) handleDataChannelMessage(participant *Participant, sfuMessa return } + participant.streamMetadata = sfuMessage.Metadata + participant.sendDataChannelMessage(event.SFUMessage{ Op: event.SFUOperationAnswer, SDP: answer.SDP, @@ -197,12 +199,6 @@ func (c *Conference) handleDataChannelMessage(participant *Participant, sfuMessa case event.SFUOperationAlive: // FIXME: Handle the heartbeat message here (updating the last timestamp etc). case event.SFUOperationMetadata: - streamIDs := make([]string, 0, len(sfuMessage.Metadata)) - for streamID := range sfuMessage.Metadata { - streamIDs = append(streamIDs, streamID) - } - participant.logger.Infof("Received metadata over DC: %v", streamIDs) - participant.streamMetadata = sfuMessage.Metadata c.resendMetadataToAllExcept(participant.id) } diff --git a/pkg/conference/state.go b/pkg/conference/state.go index 129cc9b..8ec176e 100644 --- a/pkg/conference/state.go +++ b/pkg/conference/state.go @@ -82,14 +82,6 @@ func (c *Conference) getAvailableStreamsFor(forParticipant ParticipantID) event. streamsMetadata[streamID] = metadata } else { participant.logger.Warnf("Don't have metadata for stream %s", streamID) - // FIXME: A hacky way to nevertheless send a metadata about the stream and track for which we have - // no metadata. This is against the MSC actually since we know nothing about the stream in - // this case. But it was implemented like this in the `main` branch of the SFU. - streamsMetadata[streamID] = event.CallSDPStreamMetadataObject{ - UserID: participant.id.UserID, - DeviceID: participant.id.DeviceID, - Tracks: event.CallSDPStreamMetadataTracks{trackID: event.CallSDPStreamMetadataTrack{}}, - } } } } From 74014f0e2dc4201d35132438d57bf7f5898b05a0 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Thu, 1 Dec 2022 00:37:56 +0100 Subject: [PATCH 45/62] minor: fix typo in a comment in a signaling comment MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Šimon Brandner --- pkg/signaling/matrix.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/signaling/matrix.go b/pkg/signaling/matrix.go index 31aef72..245f022 100644 --- a/pkg/signaling/matrix.go +++ b/pkg/signaling/matrix.go @@ -31,7 +31,7 @@ type MatrixForConference struct { conferenceID string } -// Create a new Matrix client that abstarcts outgoing Matrix messages from a given conference. +// Create a new Matrix client that abstracts outgoing Matrix messages from a given conference. func (m *MatrixClient) CreateForConference(conferenceID string) *MatrixForConference { return &MatrixForConference{ client: m.client, From 5fe188d90c84200bc109355a9777ea35d54f13b4 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Thu, 1 Dec 2022 17:22:52 +0100 Subject: [PATCH 46/62] Update the documentation in a conference state MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Šimon Brandner --- pkg/conference/state.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/conference/state.go b/pkg/conference/state.go index 8ec176e..5acf97d 100644 --- a/pkg/conference/state.go +++ b/pkg/conference/state.go @@ -90,7 +90,7 @@ func (c *Conference) getAvailableStreamsFor(forParticipant ParticipantID) event. return streamsMetadata } -// Helper that returns the list of streams inside this conference that match the given stream IDs. +// Helper that returns the list of streams inside this conference that match the given stream IDs and track IDs. func (c *Conference) getTracks(identifiers []event.SFUTrackDescription) []*webrtc.TrackLocalStaticRTP { tracks := make([]*webrtc.TrackLocalStaticRTP, 0) for _, participant := range c.participants { From 88997da6d6afccd58096e341058f8916361bc5b5 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Thu, 1 Dec 2022 21:02:59 +0100 Subject: [PATCH 47/62] minor: rename `UserID` to `userID` --- pkg/conference/start.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/conference/start.go b/pkg/conference/start.go index 85e38a8..91eadac 100644 --- a/pkg/conference/start.go +++ b/pkg/conference/start.go @@ -31,7 +31,7 @@ func StartConference( config Config, signaling signaling.MatrixSignaling, conferenceEndNotifier ConferenceEndNotifier, - UserID id.UserID, + userID id.UserID, inviteEvent *event.CallInviteEventContent, ) (*common.Sender[MatrixMessage], error) { sender, receiver := common.NewChannel[MatrixMessage]() @@ -47,7 +47,7 @@ func StartConference( logger: logrus.WithFields(logrus.Fields{"conf_id": confID}), } - participantID := ParticipantID{UserID: UserID, DeviceID: inviteEvent.DeviceID, CallID: inviteEvent.CallID} + participantID := ParticipantID{UserID: userID, DeviceID: inviteEvent.DeviceID, CallID: inviteEvent.CallID} if err := conference.onNewParticipant(participantID, inviteEvent); err != nil { return nil, err } From d2cce02d05b8400ead165d8ceed82a9e32bbfc5b Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Thu, 1 Dec 2022 23:37:30 +0100 Subject: [PATCH 48/62] peer: implement heartbeat handling for keepalive Note that we start the keepalive routine once the peer is created. We do it like this since the keepalive deadline is actually quite high and I would say that if within that deadline no heartbeat messages were sent, then we can consider the connection as stalled. I.e. starting the keepalive timer only once the peer is connected is like sparing a second that a peer normally needs to establish a connection? --- pkg/conference/config.go | 2 +- pkg/conference/matrix.go | 5 ++++- pkg/conference/processor.go | 8 ++++---- pkg/peer/keepalive.go | 23 +++++++++++++++++++++++ pkg/peer/messages.go | 5 ++++- pkg/peer/peer.go | 14 ++++++++++++++ pkg/peer/webrtc.go | 3 ++- 7 files changed, 52 insertions(+), 8 deletions(-) create mode 100644 pkg/peer/keepalive.go diff --git a/pkg/conference/config.go b/pkg/conference/config.go index 81952af..a239d60 100644 --- a/pkg/conference/config.go +++ b/pkg/conference/config.go @@ -3,6 +3,6 @@ package conference // Configuration for the group conferences (calls). type Config struct { // Keep-alive timeout for WebRTC connections. If no keep-alive has been received - // from the client for this duration, the connection is considered dead. + // from the client for this duration, the connection is considered dead (in seconds). KeepAliveTimeout int `yaml:"timeout"` } diff --git a/pkg/conference/matrix.go b/pkg/conference/matrix.go index 535d0ab..61bd40c 100644 --- a/pkg/conference/matrix.go +++ b/pkg/conference/matrix.go @@ -1,6 +1,8 @@ package conference import ( + "time" + "github.com/matrix-org/waterfall/pkg/common" "github.com/matrix-org/waterfall/pkg/peer" "github.com/pion/webrtc/v3" @@ -50,7 +52,8 @@ func (c *Conference) onNewParticipant(participantID ParticipantID, inviteEvent * } else { messageSink := common.NewMessageSink(participantID, c.peerMessages) - peer, answer, err := peer.NewPeer(inviteEvent.Offer.SDP, messageSink, logger) + keepAliveDeadline := time.Duration(c.config.KeepAliveTimeout) * time.Second + peer, answer, err := peer.NewPeer(inviteEvent.Offer.SDP, messageSink, logger, keepAliveDeadline) if err != nil { logger.WithError(err).Errorf("Failed to process SDP offer") return err diff --git a/pkg/conference/processor.go b/pkg/conference/processor.go index b4d0e5b..8a5b681 100644 --- a/pkg/conference/processor.go +++ b/pkg/conference/processor.go @@ -51,9 +51,9 @@ func (c *Conference) processPeerMessage(message common.Message[ParticipantID, pe participant.logger.Info("Joined the call") case peer.LeftTheCall: - participant.logger.Info("Left the call") + participant.logger.Info("Left the call: %s", msg.Reason) c.removeParticipant(message.Sender) - c.signaling.SendHangup(participant.asMatrixRecipient(), event.CallHangupUnknownError) + c.signaling.SendHangup(participant.asMatrixRecipient(), msg.Reason) case peer.NewTrackPublished: participant.logger.Infof("Published new track: %s", msg.Track.ID()) @@ -195,9 +195,9 @@ func (c *Conference) handleDataChannelMessage(participant *Participant, sfuMessa case event.SFUOperationUnpublish: participant.logger.Info("Received unpublish over DC") - // TODO: Clarify the semantics of unpublish. case event.SFUOperationAlive: - // FIXME: Handle the heartbeat message here (updating the last timestamp etc). + participant.peer.ProcessHeartbeat() + case event.SFUOperationMetadata: participant.streamMetadata = sfuMessage.Metadata c.resendMetadataToAllExcept(participant.id) diff --git a/pkg/peer/keepalive.go b/pkg/peer/keepalive.go new file mode 100644 index 0000000..b66f2ed --- /dev/null +++ b/pkg/peer/keepalive.go @@ -0,0 +1,23 @@ +package peer + +import "time" + +type HeartBeat struct{} + +// Starts a goroutine that will execute `onDeadLine` closure in case nothing has been published +// to the `heartBeat` channel for `deadline` duration. The goroutine stops once the channel is closed. +func startKeepAlive(deadline time.Duration, heartBeat <-chan HeartBeat, onDeadLine func()) { + go func() { + for { + select { + case <-time.After(deadline): + onDeadLine() + return + case _, ok := <-heartBeat: + if !ok { + return + } + } + } + }() +} diff --git a/pkg/peer/messages.go b/pkg/peer/messages.go index 7d8adf8..6a05a82 100644 --- a/pkg/peer/messages.go +++ b/pkg/peer/messages.go @@ -2,6 +2,7 @@ package peer import ( "github.com/pion/webrtc/v3" + "maunium.net/go/mautrix/event" ) // Due to the limitation of Go, we're using the `interface{}` to be able to use switch the actual @@ -10,7 +11,9 @@ type MessageContent = interface{} type JoinedTheCall struct{} -type LeftTheCall struct{} +type LeftTheCall struct { + Reason event.CallHangupReason +} type NewTrackPublished struct { Track *webrtc.TrackLocalStaticRTP diff --git a/pkg/peer/peer.go b/pkg/peer/peer.go index 80351b8..ee43324 100644 --- a/pkg/peer/peer.go +++ b/pkg/peer/peer.go @@ -3,10 +3,12 @@ package peer import ( "errors" "sync" + "time" "github.com/matrix-org/waterfall/pkg/common" "github.com/pion/webrtc/v3" "github.com/sirupsen/logrus" + "maunium.net/go/mautrix/event" ) var ( @@ -28,6 +30,7 @@ type Peer[ID comparable] struct { logger *logrus.Entry peerConnection *webrtc.PeerConnection sink *common.MessageSink[ID, MessageContent] + heartbeat chan HeartBeat dataChannelMutex sync.Mutex dataChannel *webrtc.DataChannel @@ -38,6 +41,7 @@ func NewPeer[ID comparable]( sdpOffer string, sink *common.MessageSink[ID, MessageContent], logger *logrus.Entry, + keepAliveDeadline time.Duration, ) (*Peer[ID], *webrtc.SessionDescription, error) { peerConnection, err := webrtc.NewPeerConnection(webrtc.Configuration{}) if err != nil { @@ -49,6 +53,7 @@ func NewPeer[ID comparable]( logger: logger, peerConnection: peerConnection, sink: sink, + heartbeat: make(chan HeartBeat, common.UnboundedChannelSize), } peerConnection.OnTrack(peer.onRtpTrackReceived) @@ -63,6 +68,8 @@ func NewPeer[ID comparable]( if sdpAnswer, err := peer.ProcessSDPOffer(sdpOffer); err != nil { return nil, nil, err } else { + onDeadline := func() { peer.sink.Send(LeftTheCall{event.CallHangupKeepAliveTimeout}) } + startKeepAlive(keepAliveDeadline, peer.heartbeat, onDeadline) return peer, sdpAnswer, nil } } @@ -192,3 +199,10 @@ func (p *Peer[ID]) ProcessSDPOffer(sdpOffer string) (*webrtc.SessionDescription, return &answer, nil } + +// New heartbeat received (keep-alive message that is periodically sent by the remote peer). +// We need to update the last heartbeat time. If the peer is not active for too long, we will +// consider peer's connection as stalled and will close it. +func (p *Peer[ID]) ProcessHeartbeat() { + p.heartbeat <- HeartBeat{} +} diff --git a/pkg/peer/webrtc.go b/pkg/peer/webrtc.go index 1a62358..2a97155 100644 --- a/pkg/peer/webrtc.go +++ b/pkg/peer/webrtc.go @@ -7,6 +7,7 @@ import ( "github.com/pion/rtcp" "github.com/pion/webrtc/v3" + "maunium.net/go/mautrix/event" ) // A callback that is called once we receive first RTP packets from a track, i.e. @@ -127,7 +128,7 @@ func (p *Peer[ID]) onConnectionStateChanged(state webrtc.PeerConnectionState) { switch state { case webrtc.PeerConnectionStateFailed, webrtc.PeerConnectionStateDisconnected, webrtc.PeerConnectionStateClosed: - p.sink.Send(LeftTheCall{}) + p.sink.Send(LeftTheCall{event.CallHangupUserHangup}) case webrtc.PeerConnectionStateConnected: p.sink.Send(JoinedTheCall{}) } From ec679069f4949e0b9d20d6d99052ed85a1963840 Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Fri, 2 Dec 2022 18:25:52 +0100 Subject: [PATCH 49/62] minor: rename `RunSync` -> `RunSyncing()` --- pkg/main.go | 2 +- pkg/signaling/client.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/main.go b/pkg/main.go index 05a8f51..bed8313 100644 --- a/pkg/main.go +++ b/pkg/main.go @@ -92,7 +92,7 @@ func main() { routerChannel := newRouter(matrixClient, config.Conference) // Start matrix client sync. This function will block until the sync fails. - matrixClient.RunSync(func(e *event.Event) { + matrixClient.RunSyncing(func(e *event.Event) { routerChannel <- e }) } diff --git a/pkg/signaling/client.go b/pkg/signaling/client.go index a16313f..c00e519 100644 --- a/pkg/signaling/client.go +++ b/pkg/signaling/client.go @@ -35,7 +35,7 @@ func NewMatrixClient(config Config) *MatrixClient { // Starts the Matrix client and connects to the homeserver, // Returns only when the sync with Matrix fails. -func (m *MatrixClient) RunSync(callback func(*event.Event)) { +func (m *MatrixClient) RunSyncing(callback func(*event.Event)) { syncer, ok := m.client.Syncer.(*mautrix.DefaultSyncer) if !ok { logrus.Panic("Syncer is not DefaultSyncer") From 4ef47e0c845f67d28a2eb1ff63471661034a931e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0imon=20Brandner?= Date: Sat, 3 Dec 2022 09:14:01 +0100 Subject: [PATCH 50/62] Remove ugly RTCP handling MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Šimon Brandner --- pkg/peer/webrtc.go | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/pkg/peer/webrtc.go b/pkg/peer/webrtc.go index 2a97155..8806689 100644 --- a/pkg/peer/webrtc.go +++ b/pkg/peer/webrtc.go @@ -3,9 +3,7 @@ package peer import ( "errors" "io" - "time" - "github.com/pion/rtcp" "github.com/pion/webrtc/v3" "maunium.net/go/mautrix/event" ) @@ -13,22 +11,6 @@ import ( // A callback that is called once we receive first RTP packets from a track, i.e. // we call this function each time a new track is received. func (p *Peer[ID]) onRtpTrackReceived(remoteTrack *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) { - // Send a PLI on an interval so that the publisher is pushing a keyframe every rtcpPLIInterval. - // This can be less wasteful by processing incoming RTCP events, then we would emit a NACK/PLI - // when a viewer requests it. - // - // TODO: Add RTCP handling based on the PR from @SimonBrandner. - go func() { - ticker := time.NewTicker(time.Millisecond * 500) // every 500ms - for range ticker.C { - rtcp := []rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: uint32(remoteTrack.SSRC())}} - if err := p.peerConnection.WriteRTCP(rtcp); err != nil && !errors.Is(err, io.ErrClosedPipe) { - p.logger.Errorf("Failed to send RTCP PLI: %v", err) - return - } - } - }() - // Create a local track, all our SFU clients that are subscribed to this // peer (publisher) wil be fed via this track. localTrack, err := webrtc.NewTrackLocalStaticRTP( From ce110b5cb0cc0349af56a92e94b404a9f66d4bde Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0imon=20Brandner?= Date: Sat, 3 Dec 2022 09:17:30 +0100 Subject: [PATCH 51/62] Implement RTCP forwarding MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Šimon Brandner --- pkg/conference/matrix.go | 2 +- pkg/conference/participant.go | 8 ++++- pkg/conference/processor.go | 18 +++++++++- pkg/conference/state.go | 10 +++--- pkg/peer/messages.go | 13 ++++++++ pkg/peer/peer.go | 62 +++++++++++++++++++++++++++++++++-- 6 files changed, 103 insertions(+), 10 deletions(-) diff --git a/pkg/conference/matrix.go b/pkg/conference/matrix.go index 61bd40c..691a1f5 100644 --- a/pkg/conference/matrix.go +++ b/pkg/conference/matrix.go @@ -65,7 +65,7 @@ func (c *Conference) onNewParticipant(participantID ParticipantID, inviteEvent * logger: logger, remoteSessionID: inviteEvent.SenderSessionID, streamMetadata: inviteEvent.SDPStreamMetadata, - publishedTracks: make(map[event.SFUTrackDescription]*webrtc.TrackLocalStaticRTP), + publishedTracks: make(map[event.SFUTrackDescription]PublishedTrack), } c.participants[participantID] = participant diff --git a/pkg/conference/participant.go b/pkg/conference/participant.go index 747a267..a668e37 100644 --- a/pkg/conference/participant.go +++ b/pkg/conference/participant.go @@ -2,6 +2,7 @@ package conference import ( "encoding/json" + "sync/atomic" "github.com/matrix-org/waterfall/pkg/peer" "github.com/matrix-org/waterfall/pkg/signaling" @@ -19,6 +20,11 @@ type ParticipantID struct { CallID string } +type PublishedTrack struct { + Track *webrtc.TrackLocalStaticRTP + LastPLITimestamp atomic.Int64 +} + // Participant represents a participant in the conference. type Participant struct { id ParticipantID @@ -26,7 +32,7 @@ type Participant struct { peer *peer.Peer[ParticipantID] remoteSessionID id.SessionID streamMetadata event.CallSDPStreamMetadata - publishedTracks map[event.SFUTrackDescription]*webrtc.TrackLocalStaticRTP + publishedTracks map[event.SFUTrackDescription]PublishedTrack } func (p *Participant) asMatrixRecipient() signaling.MatrixRecipient { diff --git a/pkg/conference/processor.go b/pkg/conference/processor.go index 8a5b681..f5d4375 100644 --- a/pkg/conference/processor.go +++ b/pkg/conference/processor.go @@ -67,7 +67,7 @@ func (c *Conference) processPeerMessage(message common.Message[ParticipantID, pe return } - participant.publishedTracks[key] = msg.Track + participant.publishedTracks[key] = PublishedTrack{Track: msg.Track} c.resendMetadataToAllExcept(participant.id) case peer.PublishedTrackFailed: @@ -129,6 +129,22 @@ func (c *Conference) processPeerMessage(message common.Message[ParticipantID, pe Op: event.SFUOperationMetadata, Metadata: c.getAvailableStreamsFor(participant.id), }) + case peer.ForwardRTCP: + for _, participant := range c.participants { + for _, publishedTrack := range participant.publishedTracks { + if publishedTrack.Track.StreamID() == msg.StreamID && publishedTrack.Track.ID() == msg.TrackID { + participant.peer.WriteRTCP(msg.Packets, msg.StreamID, msg.TrackID, publishedTrack.LastPLITimestamp.Load()) + } + } + } + case peer.PLISent: + for _, participant := range c.participants { + for _, publishedTrack := range participant.publishedTracks { + if publishedTrack.Track.StreamID() == msg.StreamID && publishedTrack.Track.ID() == msg.TrackID { + publishedTrack.LastPLITimestamp.Store(msg.Timestamp) + } + } + } default: c.logger.Errorf("Unknown message type: %T", msg) diff --git a/pkg/conference/state.go b/pkg/conference/state.go index 5acf97d..abc98e9 100644 --- a/pkg/conference/state.go +++ b/pkg/conference/state.go @@ -6,7 +6,6 @@ import ( "github.com/matrix-org/waterfall/pkg/signaling" "github.com/pion/webrtc/v3" "github.com/sirupsen/logrus" - "golang.org/x/exp/maps" "maunium.net/go/mautrix/event" ) @@ -56,7 +55,10 @@ func (c *Conference) removeParticipant(participantID ParticipantID) { c.resendMetadataToAllExcept(participantID) // Remove the participant's tracks from all participants who might have subscribed to them. - obsoleteTracks := maps.Values(participant.publishedTracks) + obsoleteTracks := []*webrtc.TrackLocalStaticRTP{} + for _, publishedTrack := range participant.publishedTracks { + obsoleteTracks = append(obsoleteTracks, publishedTrack.Track) + } for _, otherParticipant := range c.participants { otherParticipant.peer.UnsubscribeFrom(obsoleteTracks) } @@ -72,7 +74,7 @@ func (c *Conference) getAvailableStreamsFor(forParticipant ParticipantID) event. // Now, find out which of published tracks belong to the streams for which we have metadata // available and construct a metadata map for a given participant based on that. for _, track := range participant.publishedTracks { - trackID, streamID := track.ID(), track.StreamID() + trackID, streamID := track.Track.ID(), track.Track.StreamID() if metadata, ok := streamsMetadata[streamID]; ok { metadata.Tracks[trackID] = event.CallSDPStreamMetadataTrack{} @@ -97,7 +99,7 @@ func (c *Conference) getTracks(identifiers []event.SFUTrackDescription) []*webrt // Check if this participant has any of the tracks that we're looking for. for _, identifier := range identifiers { if track, ok := participant.publishedTracks[identifier]; ok { - tracks = append(tracks, track) + tracks = append(tracks, track.Track) } } } diff --git a/pkg/peer/messages.go b/pkg/peer/messages.go index 6a05a82..e4b155d 100644 --- a/pkg/peer/messages.go +++ b/pkg/peer/messages.go @@ -1,6 +1,7 @@ package peer import ( + "github.com/pion/rtcp" "github.com/pion/webrtc/v3" "maunium.net/go/mautrix/event" ) @@ -38,3 +39,15 @@ type DataChannelMessage struct { } type DataChannelAvailable struct{} + +type ForwardRTCP struct { + Packets []rtcp.Packet + StreamID string + TrackID string +} + +type PLISent struct { + Timestamp int64 + StreamID string + TrackID string +} diff --git a/pkg/peer/peer.go b/pkg/peer/peer.go index ee43324..4692b2f 100644 --- a/pkg/peer/peer.go +++ b/pkg/peer/peer.go @@ -2,10 +2,12 @@ package peer import ( "errors" + "io" "sync" "time" "github.com/matrix-org/waterfall/pkg/common" + "github.com/pion/rtcp" "github.com/pion/webrtc/v3" "github.com/sirupsen/logrus" "maunium.net/go/mautrix/event" @@ -22,6 +24,8 @@ var ( ErrCantSubscribeToTrack = errors.New("can't subscribe to track") ) +const minimalPLIInterval = time.Millisecond * 500 + // A wrapped representation of the peer connection (single peer in the call). // The peer gets information about the things happening outside via public methods // and informs the outside world about the things happening inside the peer by posting @@ -98,17 +102,69 @@ func (p *Peer[ID]) SubscribeTo(track *webrtc.TrackLocalStaticRTP) error { // Before these packets are returned they are processed by interceptors. For things // like NACK this needs to be called. go func() { - rtcpBuf := make([]byte, 1500) for { - if _, _, rtcpErr := rtpSender.Read(rtcpBuf); rtcpErr != nil { - return + packets, _, err := rtpSender.ReadRTCP() + if err != nil { + if errors.Is(err, io.ErrClosedPipe) || errors.Is(err, io.EOF) { + return + } + + p.logger.WithError(err).Warn("failed to read RTCP on track") } + + p.sink.Send(ForwardRTCP{Packets: packets, TrackID: track.ID(), StreamID: track.StreamID()}) } }() return nil } +func (p *Peer[ID]) WriteRTCP(packets []rtcp.Packet, streamID string, trackID string, lastPLITimestamp int64) { + packetsToSend := []rtcp.Packet{} + var mediaSSRC uint32 + for _, receiver := range p.peerConnection.GetReceivers() { + if receiver.Track().ID() == trackID && receiver.Track().StreamID() == streamID { + mediaSSRC = uint32(receiver.Track().SSRC()) + break + } + } + + for _, packet := range packets { + switch typedPacket := packet.(type) { + // We mung the packets here, so that the SSRCs match what the + // receiver expects: + // The media SSRC is the SSRC of the media about which the packet is + // reporting; therefore, we mung it to be the SSRC of the publishing + // participant's track. Without this, it would be SSRC of the SFU's + // track which isn't right + case *rtcp.PictureLossIndication: + // Since we sometimes spam the sender with PLIs, make sure we don't send + // them way too often + if time.Now().UnixNano()-lastPLITimestamp < minimalPLIInterval.Nanoseconds() { + continue + } + + p.sink.Send(PLISent{Timestamp: time.Now().UnixNano(), StreamID: streamID, TrackID: trackID}) + + typedPacket.MediaSSRC = mediaSSRC + packetsToSend = append(packetsToSend, typedPacket) + case *rtcp.FullIntraRequest: + typedPacket.MediaSSRC = mediaSSRC + packetsToSend = append(packetsToSend, typedPacket) + } + + packetsToSend = append(packetsToSend, packet) + } + + if len(packetsToSend) != 0 { + if err := p.peerConnection.WriteRTCP(packetsToSend); err != nil { + if !errors.Is(err, io.ErrClosedPipe) { + p.logger.WithError(err).Warn("failed to write RTCP on track") + } + } + } +} + // Unsubscribes from the given list of tracks. func (p *Peer[ID]) UnsubscribeFrom(tracks []*webrtc.TrackLocalStaticRTP) { // That's unfortunately an O(m*n) operation, but we don't expect the number of tracks to be big. From ae5da43c94702e4769fb976938325d6851a9774d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0imon=20Brandner?= Date: Sat, 3 Dec 2022 09:59:36 +0100 Subject: [PATCH 52/62] Further refactor code MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Šimon Brandner --- .../data_channel_message_processor.go | 78 ++++++ ...{matrix.go => matrix_message_processor.go} | 0 pkg/conference/messsage_processor.go | 87 +++++++ pkg/conference/peer_message_processor.go | 134 ++++++++++ pkg/conference/processor.go | 236 ------------------ 5 files changed, 299 insertions(+), 236 deletions(-) create mode 100644 pkg/conference/data_channel_message_processor.go rename pkg/conference/{matrix.go => matrix_message_processor.go} (100%) create mode 100644 pkg/conference/messsage_processor.go create mode 100644 pkg/conference/peer_message_processor.go delete mode 100644 pkg/conference/processor.go diff --git a/pkg/conference/data_channel_message_processor.go b/pkg/conference/data_channel_message_processor.go new file mode 100644 index 0000000..4edf49e --- /dev/null +++ b/pkg/conference/data_channel_message_processor.go @@ -0,0 +1,78 @@ +package conference + +import ( + "github.com/pion/webrtc/v3" + "golang.org/x/exp/slices" + "maunium.net/go/mautrix/event" +) + +// Handle the `SFUMessage` event from the DataChannel message. +func (c *Conference) processSelectDCMessage(participant *Participant, msg event.SFUMessage) { + participant.logger.Info("Received select request over DC") + + // Find tracks based on what we were asked for. + tracks := c.getTracks(msg.Start) + + // Let's check if we have all the tracks that we were asked for are there. + // If not, we will list which are not available (later on we must inform participant + // about it unless the participant retries it). + if len(tracks) != len(msg.Start) { + for _, expected := range msg.Start { + found := slices.IndexFunc(tracks, func(track *webrtc.TrackLocalStaticRTP) bool { + return track.StreamID() == expected.StreamID && track.ID() == expected.TrackID + }) + + if found == -1 { + c.logger.Warnf("Track not found: %s", expected.TrackID) + } + } + } + + // Subscribe to the found tracks. + for _, track := range tracks { + if err := participant.peer.SubscribeTo(track); err != nil { + participant.logger.Errorf("Failed to subscribe to track: %v", err) + return + } + } +} + +func (c *Conference) processAnswerDCMessage(participant *Participant, msg event.SFUMessage) { + participant.logger.Info("Received SDP answer over DC") + + if err := participant.peer.ProcessSDPAnswer(msg.SDP); err != nil { + participant.logger.Errorf("Failed to set SDP answer: %v", err) + return + } +} + +func (c *Conference) processPublishDCMessage(participant *Participant, msg event.SFUMessage) { + participant.logger.Info("Received SDP offer over DC") + + answer, err := participant.peer.ProcessSDPOffer(msg.SDP) + if err != nil { + participant.logger.Errorf("Failed to set SDP offer: %v", err) + return + } + + participant.streamMetadata = msg.Metadata + + participant.sendDataChannelMessage(event.SFUMessage{ + Op: event.SFUOperationAnswer, + SDP: answer.SDP, + Metadata: c.getAvailableStreamsFor(participant.id), + }) +} + +func (c *Conference) processUnpublishDCMessage(participant *Participant) { + participant.logger.Info("Received unpublish over DC") +} + +func (c *Conference) processAliveDCMessage(participant *Participant) { + participant.peer.ProcessHeartbeat() +} + +func (c *Conference) processMetadataDCMessage(participant *Participant, msg event.SFUMessage) { + participant.streamMetadata = msg.Metadata + c.resendMetadataToAllExcept(participant.id) +} diff --git a/pkg/conference/matrix.go b/pkg/conference/matrix_message_processor.go similarity index 100% rename from pkg/conference/matrix.go rename to pkg/conference/matrix_message_processor.go diff --git a/pkg/conference/messsage_processor.go b/pkg/conference/messsage_processor.go new file mode 100644 index 0000000..9a54b09 --- /dev/null +++ b/pkg/conference/messsage_processor.go @@ -0,0 +1,87 @@ +package conference + +import ( + "errors" + + "github.com/matrix-org/waterfall/pkg/common" + "github.com/matrix-org/waterfall/pkg/peer" + "maunium.net/go/mautrix/event" +) + +// Listen on messages from incoming channels and process them. +// This is essentially the main loop of the conference. +// If this function returns, the conference is over. +func (c *Conference) processMessages() { + for { + select { + case msg := <-c.peerMessages: + c.processPeerMessage(msg) + case msg := <-c.matrixMessages.Channel: + c.processMatrixMessage(msg) + } + + // If there are no more participants, stop the conference. + if len(c.participants) == 0 { + c.logger.Info("No more participants, stopping the conference") + // Close the channel so that the sender can't push any messages. + unreadMessages := c.matrixMessages.Close() + + // Send the information that we ended to the owner and pass the message + // that we did not process (so that we don't drop it silently). + c.endNotifier.Notify(unreadMessages) + return + } + } +} + +// Process a message from a local peer. +func (c *Conference) processPeerMessage(message common.Message[ParticipantID, peer.MessageContent]) { + participant := c.getParticipant(message.Sender, errors.New("received a message from a deleted participant")) + if participant == nil { + return + } + + // Since Go does not support ADTs, we have to use a switch statement to + // determine the actual type of the message. + switch msg := message.Content.(type) { + case peer.JoinedTheCall: + c.processJoinedTheCallMessage(participant, msg) + case peer.LeftTheCall: + c.processLeftTheCallMessage(participant, msg) + case peer.NewTrackPublished: + c.processNewTrackPublishedMessage(participant, msg) + case peer.PublishedTrackFailed: + c.processPublishedTrackFailedMessage(participant, msg) + case peer.NewICECandidate: + c.processNewICECandidateMessage(participant, msg) + case peer.ICEGatheringComplete: + c.processICEGatheringCompleteMessage(participant, msg) + case peer.RenegotiationRequired: + c.processRenegotiationRequiredMessage(participant, msg) + case peer.DataChannelMessage: + c.processDataChannelMessage(participant, msg) + case peer.DataChannelAvailable: + c.processDataChannelAvailableMessage(participant, msg) + case peer.ForwardRTCP: + c.processForwardRTCPMessage(msg) + case peer.PLISent: + c.processPLISentMessage(msg) + default: + c.logger.Errorf("Unknown message type: %T", msg) + } +} + +func (c *Conference) processMatrixMessage(msg MatrixMessage) { + switch ev := msg.Content.(type) { + case *event.CallInviteEventContent: + c.onNewParticipant(msg.Sender, ev) + case *event.CallCandidatesEventContent: + c.onCandidates(msg.Sender, ev) + case *event.CallSelectAnswerEventContent: + c.onSelectAnswer(msg.Sender, ev) + case *event.CallHangupEventContent: + c.onHangup(msg.Sender, ev) + default: + c.logger.Errorf("Unexpected event type: %T", ev) + } +} diff --git a/pkg/conference/peer_message_processor.go b/pkg/conference/peer_message_processor.go new file mode 100644 index 0000000..be4ec47 --- /dev/null +++ b/pkg/conference/peer_message_processor.go @@ -0,0 +1,134 @@ +package conference + +import ( + "encoding/json" + + "github.com/matrix-org/waterfall/pkg/peer" + "github.com/pion/webrtc/v3" + "maunium.net/go/mautrix/event" +) + +func (c *Conference) processJoinedTheCallMessage(participant *Participant, message peer.JoinedTheCall) { + participant.logger.Info("Joined the call") +} + +func (c *Conference) processLeftTheCallMessage(participant *Participant, msg peer.LeftTheCall) { + participant.logger.Info("Left the call: %s", msg.Reason) + c.removeParticipant(participant.id) + c.signaling.SendHangup(participant.asMatrixRecipient(), msg.Reason) +} + +func (c *Conference) processNewTrackPublishedMessage(participant *Participant, msg peer.NewTrackPublished) { + participant.logger.Infof("Published new track: %s", msg.Track.ID()) + key := event.SFUTrackDescription{ + StreamID: msg.Track.StreamID(), + TrackID: msg.Track.ID(), + } + + if _, ok := participant.publishedTracks[key]; ok { + c.logger.Errorf("Track already published: %v", key) + return + } + + participant.publishedTracks[key] = PublishedTrack{Track: msg.Track} + c.resendMetadataToAllExcept(participant.id) +} + +func (c *Conference) processPublishedTrackFailedMessage(participant *Participant, msg peer.PublishedTrackFailed) { + participant.logger.Infof("Failed published track: %s", msg.Track.ID()) + delete(participant.publishedTracks, event.SFUTrackDescription{ + StreamID: msg.Track.StreamID(), + TrackID: msg.Track.ID(), + }) + + for _, otherParticipant := range c.participants { + if otherParticipant.id == participant.id { + continue + } + + otherParticipant.peer.UnsubscribeFrom([]*webrtc.TrackLocalStaticRTP{msg.Track}) + } + + c.resendMetadataToAllExcept(participant.id) +} + +func (c *Conference) processNewICECandidateMessage(participant *Participant, msg peer.NewICECandidate) { + participant.logger.Debug("Received a new local ICE candidate") + + // Convert WebRTC ICE candidate to Matrix ICE candidate. + jsonCandidate := msg.Candidate.ToJSON() + candidates := []event.CallCandidate{{ + Candidate: jsonCandidate.Candidate, + SDPMLineIndex: int(*jsonCandidate.SDPMLineIndex), + SDPMID: *jsonCandidate.SDPMid, + }} + c.signaling.SendICECandidates(participant.asMatrixRecipient(), candidates) +} + +func (c *Conference) processICEGatheringCompleteMessage(participant *Participant, msg peer.ICEGatheringComplete) { + participant.logger.Info("Completed local ICE gathering") + + // Send an empty array of candidates to indicate that ICE gathering is complete. + c.signaling.SendCandidatesGatheringFinished(participant.asMatrixRecipient()) +} + +func (c *Conference) processRenegotiationRequiredMessage(participant *Participant, msg peer.RenegotiationRequired) { + participant.logger.Info("Started renegotiation") + participant.sendDataChannelMessage(event.SFUMessage{ + Op: event.SFUOperationOffer, + SDP: msg.Offer.SDP, + Metadata: c.getAvailableStreamsFor(participant.id), + }) +} + +func (c *Conference) processDataChannelMessage(participant *Participant, msg peer.DataChannelMessage) { + participant.logger.Debug("Received data channel message") + var sfuMessage event.SFUMessage + if err := json.Unmarshal([]byte(msg.Message), &sfuMessage); err != nil { + c.logger.Errorf("Failed to unmarshal SFU message: %v", err) + return + } + + switch sfuMessage.Op { + case event.SFUOperationSelect: + c.processSelectDCMessage(participant, sfuMessage) + case event.SFUOperationAnswer: + c.processAnswerDCMessage(participant, sfuMessage) + case event.SFUOperationPublish: + c.processPublishDCMessage(participant, sfuMessage) + case event.SFUOperationUnpublish: + c.processUnpublishDCMessage(participant) + case event.SFUOperationAlive: + c.processAliveDCMessage(participant) + case event.SFUOperationMetadata: + c.processMetadataDCMessage(participant, sfuMessage) + } +} + +func (c *Conference) processDataChannelAvailableMessage(participant *Participant, msg peer.DataChannelAvailable) { + participant.logger.Info("Connected data channel") + participant.sendDataChannelMessage(event.SFUMessage{ + Op: event.SFUOperationMetadata, + Metadata: c.getAvailableStreamsFor(participant.id), + }) +} + +func (c *Conference) processForwardRTCPMessage(msg peer.ForwardRTCP) { + for _, participant := range c.participants { + for _, publishedTrack := range participant.publishedTracks { + if publishedTrack.Track.StreamID() == msg.StreamID && publishedTrack.Track.ID() == msg.TrackID { + participant.peer.WriteRTCP(msg.Packets, msg.StreamID, msg.TrackID, publishedTrack.LastPLITimestamp.Load()) + } + } + } +} + +func (c *Conference) processPLISentMessage(msg peer.PLISent) { + for _, participant := range c.participants { + for _, publishedTrack := range participant.publishedTracks { + if publishedTrack.Track.StreamID() == msg.StreamID && publishedTrack.Track.ID() == msg.TrackID { + publishedTrack.LastPLITimestamp.Store(msg.Timestamp) + } + } + } +} diff --git a/pkg/conference/processor.go b/pkg/conference/processor.go deleted file mode 100644 index f5d4375..0000000 --- a/pkg/conference/processor.go +++ /dev/null @@ -1,236 +0,0 @@ -package conference - -import ( - "encoding/json" - "errors" - - "github.com/matrix-org/waterfall/pkg/common" - "github.com/matrix-org/waterfall/pkg/peer" - "github.com/pion/webrtc/v3" - "golang.org/x/exp/slices" - "maunium.net/go/mautrix/event" -) - -// Listen on messages from incoming channels and process them. -// This is essentially the main loop of the conference. -// If this function returns, the conference is over. -func (c *Conference) processMessages() { - for { - select { - case msg := <-c.peerMessages: - c.processPeerMessage(msg) - case msg := <-c.matrixMessages.Channel: - c.processMatrixMessage(msg) - } - - // If there are no more participants, stop the conference. - if len(c.participants) == 0 { - c.logger.Info("No more participants, stopping the conference") - // Close the channel so that the sender can't push any messages. - unreadMessages := c.matrixMessages.Close() - - // Send the information that we ended to the owner and pass the message - // that we did not process (so that we don't drop it silently). - c.endNotifier.Notify(unreadMessages) - return - } - } -} - -// Process a message from a local peer. -func (c *Conference) processPeerMessage(message common.Message[ParticipantID, peer.MessageContent]) { - participant := c.getParticipant(message.Sender, errors.New("received a message from a deleted participant")) - if participant == nil { - return - } - - // Since Go does not support ADTs, we have to use a switch statement to - // determine the actual type of the message. - switch msg := message.Content.(type) { - case peer.JoinedTheCall: - participant.logger.Info("Joined the call") - - case peer.LeftTheCall: - participant.logger.Info("Left the call: %s", msg.Reason) - c.removeParticipant(message.Sender) - c.signaling.SendHangup(participant.asMatrixRecipient(), msg.Reason) - - case peer.NewTrackPublished: - participant.logger.Infof("Published new track: %s", msg.Track.ID()) - key := event.SFUTrackDescription{ - StreamID: msg.Track.StreamID(), - TrackID: msg.Track.ID(), - } - - if _, ok := participant.publishedTracks[key]; ok { - c.logger.Errorf("Track already published: %v", key) - return - } - - participant.publishedTracks[key] = PublishedTrack{Track: msg.Track} - c.resendMetadataToAllExcept(participant.id) - - case peer.PublishedTrackFailed: - participant.logger.Infof("Failed published track: %s", msg.Track.ID()) - delete(participant.publishedTracks, event.SFUTrackDescription{ - StreamID: msg.Track.StreamID(), - TrackID: msg.Track.ID(), - }) - - for _, otherParticipant := range c.participants { - if otherParticipant.id == participant.id { - continue - } - - otherParticipant.peer.UnsubscribeFrom([]*webrtc.TrackLocalStaticRTP{msg.Track}) - } - - c.resendMetadataToAllExcept(participant.id) - - case peer.NewICECandidate: - participant.logger.Debug("Received a new local ICE candidate") - - // Convert WebRTC ICE candidate to Matrix ICE candidate. - jsonCandidate := msg.Candidate.ToJSON() - candidates := []event.CallCandidate{{ - Candidate: jsonCandidate.Candidate, - SDPMLineIndex: int(*jsonCandidate.SDPMLineIndex), - SDPMID: *jsonCandidate.SDPMid, - }} - c.signaling.SendICECandidates(participant.asMatrixRecipient(), candidates) - - case peer.ICEGatheringComplete: - participant.logger.Info("Completed local ICE gathering") - - // Send an empty array of candidates to indicate that ICE gathering is complete. - c.signaling.SendCandidatesGatheringFinished(participant.asMatrixRecipient()) - - case peer.RenegotiationRequired: - participant.logger.Info("Started renegotiation") - participant.sendDataChannelMessage(event.SFUMessage{ - Op: event.SFUOperationOffer, - SDP: msg.Offer.SDP, - Metadata: c.getAvailableStreamsFor(participant.id), - }) - - case peer.DataChannelMessage: - participant.logger.Debug("Received data channel message") - var sfuMessage event.SFUMessage - if err := json.Unmarshal([]byte(msg.Message), &sfuMessage); err != nil { - c.logger.Errorf("Failed to unmarshal SFU message: %v", err) - return - } - - c.handleDataChannelMessage(participant, sfuMessage) - - case peer.DataChannelAvailable: - participant.logger.Info("Connected data channel") - participant.sendDataChannelMessage(event.SFUMessage{ - Op: event.SFUOperationMetadata, - Metadata: c.getAvailableStreamsFor(participant.id), - }) - case peer.ForwardRTCP: - for _, participant := range c.participants { - for _, publishedTrack := range participant.publishedTracks { - if publishedTrack.Track.StreamID() == msg.StreamID && publishedTrack.Track.ID() == msg.TrackID { - participant.peer.WriteRTCP(msg.Packets, msg.StreamID, msg.TrackID, publishedTrack.LastPLITimestamp.Load()) - } - } - } - case peer.PLISent: - for _, participant := range c.participants { - for _, publishedTrack := range participant.publishedTracks { - if publishedTrack.Track.StreamID() == msg.StreamID && publishedTrack.Track.ID() == msg.TrackID { - publishedTrack.LastPLITimestamp.Store(msg.Timestamp) - } - } - } - - default: - c.logger.Errorf("Unknown message type: %T", msg) - } -} - -// Handle the `SFUMessage` event from the DataChannel message. -func (c *Conference) handleDataChannelMessage(participant *Participant, sfuMessage event.SFUMessage) { - switch sfuMessage.Op { - case event.SFUOperationSelect: - participant.logger.Info("Received select request over DC") - - // Find tracks based on what we were asked for. - tracks := c.getTracks(sfuMessage.Start) - - // Let's check if we have all the tracks that we were asked for are there. - // If not, we will list which are not available (later on we must inform participant - // about it unless the participant retries it). - if len(tracks) != len(sfuMessage.Start) { - for _, expected := range sfuMessage.Start { - found := slices.IndexFunc(tracks, func(track *webrtc.TrackLocalStaticRTP) bool { - return track.StreamID() == expected.StreamID && track.ID() == expected.TrackID - }) - - if found == -1 { - c.logger.Warnf("Track not found: %s", expected.TrackID) - } - } - } - - // Subscribe to the found tracks. - for _, track := range tracks { - if err := participant.peer.SubscribeTo(track); err != nil { - participant.logger.Errorf("Failed to subscribe to track: %v", err) - return - } - } - - case event.SFUOperationAnswer: - participant.logger.Info("Received SDP answer over DC") - - if err := participant.peer.ProcessSDPAnswer(sfuMessage.SDP); err != nil { - participant.logger.Errorf("Failed to set SDP answer: %v", err) - return - } - - case event.SFUOperationPublish: - participant.logger.Info("Received SDP offer over DC") - - answer, err := participant.peer.ProcessSDPOffer(sfuMessage.SDP) - if err != nil { - participant.logger.Errorf("Failed to set SDP offer: %v", err) - return - } - - participant.streamMetadata = sfuMessage.Metadata - - participant.sendDataChannelMessage(event.SFUMessage{ - Op: event.SFUOperationAnswer, - SDP: answer.SDP, - Metadata: c.getAvailableStreamsFor(participant.id), - }) - - case event.SFUOperationUnpublish: - participant.logger.Info("Received unpublish over DC") - - case event.SFUOperationAlive: - participant.peer.ProcessHeartbeat() - - case event.SFUOperationMetadata: - participant.streamMetadata = sfuMessage.Metadata - c.resendMetadataToAllExcept(participant.id) - } -} - -func (c *Conference) processMatrixMessage(msg MatrixMessage) { - switch ev := msg.Content.(type) { - case *event.CallInviteEventContent: - c.onNewParticipant(msg.Sender, ev) - case *event.CallCandidatesEventContent: - c.onCandidates(msg.Sender, ev) - case *event.CallSelectAnswerEventContent: - c.onSelectAnswer(msg.Sender, ev) - case *event.CallHangupEventContent: - c.onHangup(msg.Sender, ev) - default: - c.logger.Errorf("Unexpected event type: %T", ev) - } -} From 73a69470dfb80ab2cad8b5a0e82d6344a2b87ab8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0imon=20Brandner?= Date: Mon, 5 Dec 2022 16:12:15 +0100 Subject: [PATCH 53/62] Make `minimalPLIInterval` local MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Šimon Brandner --- pkg/peer/peer.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/peer/peer.go b/pkg/peer/peer.go index 4692b2f..07d480f 100644 --- a/pkg/peer/peer.go +++ b/pkg/peer/peer.go @@ -24,8 +24,6 @@ var ( ErrCantSubscribeToTrack = errors.New("can't subscribe to track") ) -const minimalPLIInterval = time.Millisecond * 500 - // A wrapped representation of the peer connection (single peer in the call). // The peer gets information about the things happening outside via public methods // and informs the outside world about the things happening inside the peer by posting @@ -120,6 +118,8 @@ func (p *Peer[ID]) SubscribeTo(track *webrtc.TrackLocalStaticRTP) error { } func (p *Peer[ID]) WriteRTCP(packets []rtcp.Packet, streamID string, trackID string, lastPLITimestamp int64) { + const minimalPLIInterval = time.Millisecond * 500 + packetsToSend := []rtcp.Packet{} var mediaSSRC uint32 for _, receiver := range p.peerConnection.GetReceivers() { From 5a78cca5bd2f1f8e3e1098d5e2e9fb9e0ee77df0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0imon=20Brandner?= Date: Mon, 5 Dec 2022 16:25:26 +0100 Subject: [PATCH 54/62] Fix typo MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Šimon Brandner --- pkg/peer/peer.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pkg/peer/peer.go b/pkg/peer/peer.go index 07d480f..ae3af81 100644 --- a/pkg/peer/peer.go +++ b/pkg/peer/peer.go @@ -15,7 +15,7 @@ import ( var ( ErrCantCreatePeerConnection = errors.New("can't create peer connection") - ErrCantSetRemoteDecsription = errors.New("can't set remote description") + ErrCantSetRemoteDescription = errors.New("can't set remote description") ErrCantCreateAnswer = errors.New("can't create answer") ErrCantSetLocalDescription = errors.New("can't set local description") ErrCantCreateLocalDescription = errors.New("can't create local description") @@ -225,7 +225,7 @@ func (p *Peer[ID]) ProcessSDPAnswer(sdpAnswer string) error { }) if err != nil { p.logger.WithError(err).Error("failed to set remote description") - return ErrCantSetRemoteDecsription + return ErrCantSetRemoteDescription } return nil @@ -239,7 +239,7 @@ func (p *Peer[ID]) ProcessSDPOffer(sdpOffer string) (*webrtc.SessionDescription, }) if err != nil { p.logger.WithError(err).Error("failed to set remote description") - return nil, ErrCantSetRemoteDecsription + return nil, ErrCantSetRemoteDescription } answer, err := p.peerConnection.CreateAnswer(nil) From eecdde23892577ac19dafe3110077308dd654fd9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0imon=20Brandner?= Date: Mon, 5 Dec 2022 16:31:10 +0100 Subject: [PATCH 55/62] Handle missing track MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Šimon Brandner --- pkg/peer/peer.go | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/pkg/peer/peer.go b/pkg/peer/peer.go index ae3af81..9eac8c4 100644 --- a/pkg/peer/peer.go +++ b/pkg/peer/peer.go @@ -10,6 +10,7 @@ import ( "github.com/pion/rtcp" "github.com/pion/webrtc/v3" "github.com/sirupsen/logrus" + "golang.org/x/exp/slices" "maunium.net/go/mautrix/event" ) @@ -22,6 +23,7 @@ var ( ErrDataChannelNotAvailable = errors.New("data channel is not available") ErrDataChannelNotReady = errors.New("data channel is not ready") ErrCantSubscribeToTrack = errors.New("can't subscribe to track") + ErrCantWriteRTCP = errors.New("can't write RTCP") ) // A wrapped representation of the peer connection (single peer in the call). @@ -117,16 +119,21 @@ func (p *Peer[ID]) SubscribeTo(track *webrtc.TrackLocalStaticRTP) error { return nil } -func (p *Peer[ID]) WriteRTCP(packets []rtcp.Packet, streamID string, trackID string, lastPLITimestamp int64) { +func (p *Peer[ID]) WriteRTCP(packets []rtcp.Packet, streamID string, trackID string, lastPLITimestamp int64) error { const minimalPLIInterval = time.Millisecond * 500 packetsToSend := []rtcp.Packet{} var mediaSSRC uint32 - for _, receiver := range p.peerConnection.GetReceivers() { - if receiver.Track().ID() == trackID && receiver.Track().StreamID() == streamID { - mediaSSRC = uint32(receiver.Track().SSRC()) - break - } + receivers := p.peerConnection.GetReceivers() + receiverIndex := slices.IndexFunc(receivers, func(receiver *webrtc.RTPReceiver) bool { + return receiver.Track().ID() == trackID && receiver.Track().StreamID() == streamID + }) + + if receiverIndex == -1 { + p.logger.Error("failed to find track to write RTCP on") + return ErrCantWriteRTCP + } else { + mediaSSRC = uint32(receivers[receiverIndex].Track().SSRC()) } for _, packet := range packets { @@ -159,10 +166,13 @@ func (p *Peer[ID]) WriteRTCP(packets []rtcp.Packet, streamID string, trackID str if len(packetsToSend) != 0 { if err := p.peerConnection.WriteRTCP(packetsToSend); err != nil { if !errors.Is(err, io.ErrClosedPipe) { - p.logger.WithError(err).Warn("failed to write RTCP on track") + p.logger.WithError(err).Error("failed to write RTCP on track") + return err } } } + + return nil } // Unsubscribes from the given list of tracks. From e208b4bc9bb614aec920cc052fb4a23ad52ebe25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0imon=20Brandner?= Date: Mon, 5 Dec 2022 16:32:48 +0100 Subject: [PATCH 56/62] `ForwardRTCP` -> `RTCPReceived` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Šimon Brandner --- pkg/conference/messsage_processor.go | 2 +- pkg/conference/peer_message_processor.go | 2 +- pkg/peer/messages.go | 2 +- pkg/peer/peer.go | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pkg/conference/messsage_processor.go b/pkg/conference/messsage_processor.go index 9a54b09..94de6a6 100644 --- a/pkg/conference/messsage_processor.go +++ b/pkg/conference/messsage_processor.go @@ -62,7 +62,7 @@ func (c *Conference) processPeerMessage(message common.Message[ParticipantID, pe c.processDataChannelMessage(participant, msg) case peer.DataChannelAvailable: c.processDataChannelAvailableMessage(participant, msg) - case peer.ForwardRTCP: + case peer.RTCPReceived: c.processForwardRTCPMessage(msg) case peer.PLISent: c.processPLISentMessage(msg) diff --git a/pkg/conference/peer_message_processor.go b/pkg/conference/peer_message_processor.go index be4ec47..da2ffac 100644 --- a/pkg/conference/peer_message_processor.go +++ b/pkg/conference/peer_message_processor.go @@ -113,7 +113,7 @@ func (c *Conference) processDataChannelAvailableMessage(participant *Participant }) } -func (c *Conference) processForwardRTCPMessage(msg peer.ForwardRTCP) { +func (c *Conference) processForwardRTCPMessage(msg peer.RTCPReceived) { for _, participant := range c.participants { for _, publishedTrack := range participant.publishedTracks { if publishedTrack.Track.StreamID() == msg.StreamID && publishedTrack.Track.ID() == msg.TrackID { diff --git a/pkg/peer/messages.go b/pkg/peer/messages.go index e4b155d..91ac734 100644 --- a/pkg/peer/messages.go +++ b/pkg/peer/messages.go @@ -40,7 +40,7 @@ type DataChannelMessage struct { type DataChannelAvailable struct{} -type ForwardRTCP struct { +type RTCPReceived struct { Packets []rtcp.Packet StreamID string TrackID string diff --git a/pkg/peer/peer.go b/pkg/peer/peer.go index 9eac8c4..9540c5e 100644 --- a/pkg/peer/peer.go +++ b/pkg/peer/peer.go @@ -112,7 +112,7 @@ func (p *Peer[ID]) SubscribeTo(track *webrtc.TrackLocalStaticRTP) error { p.logger.WithError(err).Warn("failed to read RTCP on track") } - p.sink.Send(ForwardRTCP{Packets: packets, TrackID: track.ID(), StreamID: track.StreamID()}) + p.sink.Send(RTCPReceived{Packets: packets, TrackID: track.ID(), StreamID: track.StreamID()}) } }() From 70a4edec62fa4899d5f0da83780312d0da0aa6e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0imon=20Brandner?= Date: Mon, 5 Dec 2022 16:34:37 +0100 Subject: [PATCH 57/62] Decapitalize MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Šimon Brandner --- pkg/conference/participant.go | 4 ++-- pkg/conference/peer_message_processor.go | 10 +++++----- pkg/conference/state.go | 6 +++--- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/pkg/conference/participant.go b/pkg/conference/participant.go index a668e37..0ac696e 100644 --- a/pkg/conference/participant.go +++ b/pkg/conference/participant.go @@ -21,8 +21,8 @@ type ParticipantID struct { } type PublishedTrack struct { - Track *webrtc.TrackLocalStaticRTP - LastPLITimestamp atomic.Int64 + track *webrtc.TrackLocalStaticRTP + lastPLITimestamp atomic.Int64 } // Participant represents a participant in the conference. diff --git a/pkg/conference/peer_message_processor.go b/pkg/conference/peer_message_processor.go index da2ffac..9098bf8 100644 --- a/pkg/conference/peer_message_processor.go +++ b/pkg/conference/peer_message_processor.go @@ -30,7 +30,7 @@ func (c *Conference) processNewTrackPublishedMessage(participant *Participant, m return } - participant.publishedTracks[key] = PublishedTrack{Track: msg.Track} + participant.publishedTracks[key] = PublishedTrack{track: msg.Track} c.resendMetadataToAllExcept(participant.id) } @@ -116,8 +116,8 @@ func (c *Conference) processDataChannelAvailableMessage(participant *Participant func (c *Conference) processForwardRTCPMessage(msg peer.RTCPReceived) { for _, participant := range c.participants { for _, publishedTrack := range participant.publishedTracks { - if publishedTrack.Track.StreamID() == msg.StreamID && publishedTrack.Track.ID() == msg.TrackID { - participant.peer.WriteRTCP(msg.Packets, msg.StreamID, msg.TrackID, publishedTrack.LastPLITimestamp.Load()) + if publishedTrack.track.StreamID() == msg.StreamID && publishedTrack.track.ID() == msg.TrackID { + participant.peer.WriteRTCP(msg.Packets, msg.StreamID, msg.TrackID, publishedTrack.lastPLITimestamp.Load()) } } } @@ -126,8 +126,8 @@ func (c *Conference) processForwardRTCPMessage(msg peer.RTCPReceived) { func (c *Conference) processPLISentMessage(msg peer.PLISent) { for _, participant := range c.participants { for _, publishedTrack := range participant.publishedTracks { - if publishedTrack.Track.StreamID() == msg.StreamID && publishedTrack.Track.ID() == msg.TrackID { - publishedTrack.LastPLITimestamp.Store(msg.Timestamp) + if publishedTrack.track.StreamID() == msg.StreamID && publishedTrack.track.ID() == msg.TrackID { + publishedTrack.lastPLITimestamp.Store(msg.Timestamp) } } } diff --git a/pkg/conference/state.go b/pkg/conference/state.go index abc98e9..de43c00 100644 --- a/pkg/conference/state.go +++ b/pkg/conference/state.go @@ -57,7 +57,7 @@ func (c *Conference) removeParticipant(participantID ParticipantID) { // Remove the participant's tracks from all participants who might have subscribed to them. obsoleteTracks := []*webrtc.TrackLocalStaticRTP{} for _, publishedTrack := range participant.publishedTracks { - obsoleteTracks = append(obsoleteTracks, publishedTrack.Track) + obsoleteTracks = append(obsoleteTracks, publishedTrack.track) } for _, otherParticipant := range c.participants { otherParticipant.peer.UnsubscribeFrom(obsoleteTracks) @@ -74,7 +74,7 @@ func (c *Conference) getAvailableStreamsFor(forParticipant ParticipantID) event. // Now, find out which of published tracks belong to the streams for which we have metadata // available and construct a metadata map for a given participant based on that. for _, track := range participant.publishedTracks { - trackID, streamID := track.Track.ID(), track.Track.StreamID() + trackID, streamID := track.track.ID(), track.track.StreamID() if metadata, ok := streamsMetadata[streamID]; ok { metadata.Tracks[trackID] = event.CallSDPStreamMetadataTrack{} @@ -99,7 +99,7 @@ func (c *Conference) getTracks(identifiers []event.SFUTrackDescription) []*webrt // Check if this participant has any of the tracks that we're looking for. for _, identifier := range identifiers { if track, ok := participant.publishedTracks[identifier]; ok { - tracks = append(tracks, track.Track) + tracks = append(tracks, track.track) } } } From fd545d0e72645d61ee906f1e55cb9200879ceae2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0imon=20Brandner?= Date: Mon, 5 Dec 2022 16:36:54 +0100 Subject: [PATCH 58/62] Explain `lastPLITimestamp` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Šimon Brandner --- pkg/conference/participant.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pkg/conference/participant.go b/pkg/conference/participant.go index 0ac696e..449a377 100644 --- a/pkg/conference/participant.go +++ b/pkg/conference/participant.go @@ -21,7 +21,9 @@ type ParticipantID struct { } type PublishedTrack struct { - track *webrtc.TrackLocalStaticRTP + track *webrtc.TrackLocalStaticRTP + // The time when we sent the last PLI to the sender. We store this to avoid + // spamming the sender. lastPLITimestamp atomic.Int64 } From 2649e1f8eef660b3b912ba83af4bce57a161d056 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0imon=20Brandner?= Date: Mon, 5 Dec 2022 16:40:36 +0100 Subject: [PATCH 59/62] Use `time.Time` as type MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Šimon Brandner --- pkg/conference/participant.go | 4 ++-- pkg/conference/peer_message_processor.go | 4 ++-- pkg/peer/messages.go | 4 +++- pkg/peer/peer.go | 6 +++--- 4 files changed, 10 insertions(+), 8 deletions(-) diff --git a/pkg/conference/participant.go b/pkg/conference/participant.go index 449a377..59c2c9d 100644 --- a/pkg/conference/participant.go +++ b/pkg/conference/participant.go @@ -2,7 +2,7 @@ package conference import ( "encoding/json" - "sync/atomic" + "time" "github.com/matrix-org/waterfall/pkg/peer" "github.com/matrix-org/waterfall/pkg/signaling" @@ -24,7 +24,7 @@ type PublishedTrack struct { track *webrtc.TrackLocalStaticRTP // The time when we sent the last PLI to the sender. We store this to avoid // spamming the sender. - lastPLITimestamp atomic.Int64 + lastPLITimestamp time.Time } // Participant represents a participant in the conference. diff --git a/pkg/conference/peer_message_processor.go b/pkg/conference/peer_message_processor.go index 9098bf8..1b5d949 100644 --- a/pkg/conference/peer_message_processor.go +++ b/pkg/conference/peer_message_processor.go @@ -117,7 +117,7 @@ func (c *Conference) processForwardRTCPMessage(msg peer.RTCPReceived) { for _, participant := range c.participants { for _, publishedTrack := range participant.publishedTracks { if publishedTrack.track.StreamID() == msg.StreamID && publishedTrack.track.ID() == msg.TrackID { - participant.peer.WriteRTCP(msg.Packets, msg.StreamID, msg.TrackID, publishedTrack.lastPLITimestamp.Load()) + participant.peer.WriteRTCP(msg.Packets, msg.StreamID, msg.TrackID, publishedTrack.lastPLITimestamp) } } } @@ -127,7 +127,7 @@ func (c *Conference) processPLISentMessage(msg peer.PLISent) { for _, participant := range c.participants { for _, publishedTrack := range participant.publishedTracks { if publishedTrack.track.StreamID() == msg.StreamID && publishedTrack.track.ID() == msg.TrackID { - publishedTrack.lastPLITimestamp.Store(msg.Timestamp) + publishedTrack.lastPLITimestamp = msg.Timestamp } } } diff --git a/pkg/peer/messages.go b/pkg/peer/messages.go index 91ac734..08b0f20 100644 --- a/pkg/peer/messages.go +++ b/pkg/peer/messages.go @@ -1,6 +1,8 @@ package peer import ( + "time" + "github.com/pion/rtcp" "github.com/pion/webrtc/v3" "maunium.net/go/mautrix/event" @@ -47,7 +49,7 @@ type RTCPReceived struct { } type PLISent struct { - Timestamp int64 + Timestamp time.Time StreamID string TrackID string } diff --git a/pkg/peer/peer.go b/pkg/peer/peer.go index 9540c5e..0909504 100644 --- a/pkg/peer/peer.go +++ b/pkg/peer/peer.go @@ -119,7 +119,7 @@ func (p *Peer[ID]) SubscribeTo(track *webrtc.TrackLocalStaticRTP) error { return nil } -func (p *Peer[ID]) WriteRTCP(packets []rtcp.Packet, streamID string, trackID string, lastPLITimestamp int64) error { +func (p *Peer[ID]) WriteRTCP(packets []rtcp.Packet, streamID string, trackID string, lastPLITimestamp time.Time) error { const minimalPLIInterval = time.Millisecond * 500 packetsToSend := []rtcp.Packet{} @@ -147,11 +147,11 @@ func (p *Peer[ID]) WriteRTCP(packets []rtcp.Packet, streamID string, trackID str case *rtcp.PictureLossIndication: // Since we sometimes spam the sender with PLIs, make sure we don't send // them way too often - if time.Now().UnixNano()-lastPLITimestamp < minimalPLIInterval.Nanoseconds() { + if time.Now().UnixNano()-lastPLITimestamp.UnixNano() < minimalPLIInterval.Nanoseconds() { continue } - p.sink.Send(PLISent{Timestamp: time.Now().UnixNano(), StreamID: streamID, TrackID: trackID}) + p.sink.Send(PLISent{Timestamp: time.Now(), StreamID: streamID, TrackID: trackID}) typedPacket.MediaSSRC = mediaSSRC packetsToSend = append(packetsToSend, typedPacket) From 0369070153486b0abae3375e04cafbf123d52ad4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0imon=20Brandner?= Date: Mon, 5 Dec 2022 16:57:59 +0100 Subject: [PATCH 60/62] Simplify `lastPLITimestamp` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Šimon Brandner --- pkg/conference/peer_message_processor.go | 16 +++++----------- pkg/peer/messages.go | 8 -------- pkg/peer/peer.go | 2 -- 3 files changed, 5 insertions(+), 21 deletions(-) diff --git a/pkg/conference/peer_message_processor.go b/pkg/conference/peer_message_processor.go index 1b5d949..a815972 100644 --- a/pkg/conference/peer_message_processor.go +++ b/pkg/conference/peer_message_processor.go @@ -2,6 +2,7 @@ package conference import ( "encoding/json" + "time" "github.com/matrix-org/waterfall/pkg/peer" "github.com/pion/webrtc/v3" @@ -117,17 +118,10 @@ func (c *Conference) processForwardRTCPMessage(msg peer.RTCPReceived) { for _, participant := range c.participants { for _, publishedTrack := range participant.publishedTracks { if publishedTrack.track.StreamID() == msg.StreamID && publishedTrack.track.ID() == msg.TrackID { - participant.peer.WriteRTCP(msg.Packets, msg.StreamID, msg.TrackID, publishedTrack.lastPLITimestamp) - } - } - } -} - -func (c *Conference) processPLISentMessage(msg peer.PLISent) { - for _, participant := range c.participants { - for _, publishedTrack := range participant.publishedTracks { - if publishedTrack.track.StreamID() == msg.StreamID && publishedTrack.track.ID() == msg.TrackID { - publishedTrack.lastPLITimestamp = msg.Timestamp + err := participant.peer.WriteRTCP(msg.Packets, msg.StreamID, msg.TrackID, publishedTrack.lastPLITimestamp) + if err == nil { + publishedTrack.lastPLITimestamp = time.Now() + } } } } diff --git a/pkg/peer/messages.go b/pkg/peer/messages.go index 08b0f20..0593d72 100644 --- a/pkg/peer/messages.go +++ b/pkg/peer/messages.go @@ -1,8 +1,6 @@ package peer import ( - "time" - "github.com/pion/rtcp" "github.com/pion/webrtc/v3" "maunium.net/go/mautrix/event" @@ -47,9 +45,3 @@ type RTCPReceived struct { StreamID string TrackID string } - -type PLISent struct { - Timestamp time.Time - StreamID string - TrackID string -} diff --git a/pkg/peer/peer.go b/pkg/peer/peer.go index 0909504..a226633 100644 --- a/pkg/peer/peer.go +++ b/pkg/peer/peer.go @@ -151,8 +151,6 @@ func (p *Peer[ID]) WriteRTCP(packets []rtcp.Packet, streamID string, trackID str continue } - p.sink.Send(PLISent{Timestamp: time.Now(), StreamID: streamID, TrackID: trackID}) - typedPacket.MediaSSRC = mediaSSRC packetsToSend = append(packetsToSend, typedPacket) case *rtcp.FullIntraRequest: From dc0931838960757595a7772044e5ed5f12a0e07a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0imon=20Brandner?= Date: Mon, 5 Dec 2022 16:59:22 +0100 Subject: [PATCH 61/62] Remove leftover MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Šimon Brandner --- pkg/conference/messsage_processor.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/pkg/conference/messsage_processor.go b/pkg/conference/messsage_processor.go index 94de6a6..3865aac 100644 --- a/pkg/conference/messsage_processor.go +++ b/pkg/conference/messsage_processor.go @@ -64,8 +64,6 @@ func (c *Conference) processPeerMessage(message common.Message[ParticipantID, pe c.processDataChannelAvailableMessage(participant, msg) case peer.RTCPReceived: c.processForwardRTCPMessage(msg) - case peer.PLISent: - c.processPLISentMessage(msg) default: c.logger.Errorf("Unknown message type: %T", msg) } From 537f4c003b96e6344f31c09c44737e3e0706902d Mon Sep 17 00:00:00 2001 From: Daniel Abramov Date: Mon, 5 Dec 2022 17:50:33 +0100 Subject: [PATCH 62/62] general: resolve leftovers after rebase --- pkg/publisher.go | 195 ---------------------------------------------- pkg/subscriber.go | 136 -------------------------------- 2 files changed, 331 deletions(-) delete mode 100644 pkg/publisher.go delete mode 100644 pkg/subscriber.go diff --git a/pkg/publisher.go b/pkg/publisher.go deleted file mode 100644 index 3058860..0000000 --- a/pkg/publisher.go +++ /dev/null @@ -1,195 +0,0 @@ -/* -Copyright 2022 The Matrix.org Foundation C.I.C. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package main - -import ( - "errors" - "io" - "sync" - "sync/atomic" - "time" - - "github.com/pion/rtcp" - "github.com/pion/webrtc/v3" - "github.com/sirupsen/logrus" - "maunium.net/go/mautrix/event" -) - -const ( - minimalPLIInterval = time.Millisecond * 500 - bufferSize = 1500 -) - -type Publisher struct { - Track *webrtc.TrackRemote - Call *Call - - mutex sync.RWMutex - logger *logrus.Entry - subscribers []*Subscriber - - lastPLI atomic.Int64 -} - -func NewPublisher( - track *webrtc.TrackRemote, - call *Call, -) *Publisher { - publisher := new(Publisher) - - publisher.Track = track - publisher.Call = call - - publisher.subscribers = []*Subscriber{} - publisher.logger = call.logger.WithFields(logrus.Fields{ - "track_id": track.ID(), - "track_kind": track.Kind(), - "stream_id": track.StreamID(), - }) - - call.mutex.Lock() - call.Publishers = append(call.Publishers, publisher) - call.mutex.Unlock() - - go publisher.WriteToSubscribers() - - publisher.logger.Info("published track") - - return publisher -} - -func (p *Publisher) Subscribe(call *Call) { - subscriber := NewSubscriber(call) - subscriber.Subscribe(p) - p.AddSubscriber(subscriber) -} - -func (p *Publisher) Stop() { - removed := p.Call.RemovePublisher(p) - - if len(p.subscribers) == 0 && !removed { - return - } - - for _, subscriber := range p.subscribers { - subscriber.Unsubscribe() - p.RemoveSubscriber(subscriber) - } - - p.logger.Info("unpublished track") -} - -func (p *Publisher) AddSubscriber(subscriber *Subscriber) { - p.mutex.Lock() - defer p.mutex.Unlock() - p.subscribers = append(p.subscribers, subscriber) -} - -func (p *Publisher) RemoveSubscriber(toDelete *Subscriber) { - newSubscribers := []*Subscriber{} - - p.mutex.Lock() - for _, subscriber := range p.subscribers { - if subscriber != toDelete { - newSubscribers = append(newSubscribers, subscriber) - } - } - - p.subscribers = newSubscribers - p.mutex.Unlock() -} - -func (p *Publisher) Matches(trackDescription event.SFUTrackDescription) bool { - if p.Track.ID() != trackDescription.TrackID { - return false - } - - if p.Track.StreamID() != trackDescription.StreamID { - return false - } - - return true -} - -func (p *Publisher) WriteRTCP(packets []rtcp.Packet) { - packetsToSend := []rtcp.Packet{} - readSSRC := uint32(p.Track.SSRC()) - - for _, packet := range packets { - switch typedPacket := packet.(type) { - // We mung the packets here, so that the SSRCs match what the - // receiver expects: - // The media SSRC is the SSRC of the media about which the packet is - // reporting; therefore, we mung it to be the SSRC of the publishing - // participant's track. Without this, it would be SSRC of the SFU's - // track which isn't right - case *rtcp.PictureLossIndication: - // Since we sometimes spam the sender with PLIs, make sure we don't send - // them way too often - if time.Now().UnixNano()-p.lastPLI.Load() < minimalPLIInterval.Nanoseconds() { - continue - } - - p.lastPLI.Store(time.Now().UnixNano()) - - typedPacket.MediaSSRC = readSSRC - packetsToSend = append(packetsToSend, typedPacket) - case *rtcp.FullIntraRequest: - typedPacket.MediaSSRC = readSSRC - packetsToSend = append(packetsToSend, typedPacket) - } - - packetsToSend = append(packetsToSend, packet) - } - - if len(packetsToSend) != 0 { - if err := p.Call.PeerConnection.WriteRTCP(packetsToSend); err != nil { - if !errors.Is(err, io.ErrClosedPipe) { - p.logger.WithError(err).Warn("failed to write RTCP on track") - } - } - } -} - -func (p *Publisher) WriteToSubscribers() { - buff := make([]byte, bufferSize) - - for { - index, _, err := p.Track.Read(buff) - if err != nil { - if errors.Is(err, io.EOF) { - p.Stop() - return - } - - p.logger.WithError(err).Warn("failed to read track") - } - - for _, subscriber := range p.subscribers { - if _, err = subscriber.Track.Write(buff[:index]); err != nil { - if errors.Is(err, io.ErrClosedPipe) || errors.Is(err, io.EOF) { - subscriber.Unsubscribe() - p.RemoveSubscriber(subscriber) - - return - } - - p.logger.WithError(err).Warn("failed to write to track") - } - } - } -} diff --git a/pkg/subscriber.go b/pkg/subscriber.go deleted file mode 100644 index fbe9235..0000000 --- a/pkg/subscriber.go +++ /dev/null @@ -1,136 +0,0 @@ -/* -Copyright 2022 The Matrix.org Foundation C.I.C. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package main - -import ( - "errors" - "io" - "sync" - - "github.com/pion/webrtc/v3" - "github.com/sirupsen/logrus" -) - -type Subscriber struct { - Track *webrtc.TrackLocalStaticRTP - - mutex sync.RWMutex - logger *logrus.Entry - call *Call - sender *webrtc.RTPSender - publisher *Publisher -} - -func NewSubscriber(call *Call) *Subscriber { - subscriber := new(Subscriber) - - subscriber.call = call - subscriber.logger = call.logger - - call.mutex.Lock() - call.Subscribers = append(call.Subscribers, subscriber) - call.mutex.Unlock() - - return subscriber -} - -func (s *Subscriber) initLoggingWithTrack(track *webrtc.TrackRemote) { - s.mutex.Lock() - defer s.mutex.Unlock() - s.logger = s.call.logger.WithFields(logrus.Fields{ - "track_id": (*track).ID(), - "track_kind": (*track).Kind(), - "stream_id": (*track).StreamID(), - }) -} - -func (s *Subscriber) Subscribe(publisher *Publisher) { - s.initLoggingWithTrack(publisher.Track) - - if s.publisher != nil { - s.logger.Error("cannot subscribe, if we already are") - } - - track, err := webrtc.NewTrackLocalStaticRTP( - publisher.Track.Codec().RTPCodecCapability, - publisher.Track.ID(), - publisher.Track.StreamID(), - ) - if err != nil { - s.logger.WithError(err).Error("failed to create local static RTP track") - } - - sender, err := s.call.PeerConnection.AddTrack(track) - if err != nil { - s.logger.WithError(err).Error("failed to add track to peer connection") - } - - s.mutex.Lock() - s.Track = track - s.sender = sender - s.publisher = publisher - s.mutex.Unlock() - - if s.Track.Kind() == webrtc.RTPCodecTypeVideo { - go s.forwardRTCP() - } - - publisher.AddSubscriber(s) - - s.logger.Info("subscribed") -} - -func (s *Subscriber) Unsubscribe() { - if s.publisher == nil { - return - } - - if s.call.PeerConnection.ConnectionState() != webrtc.PeerConnectionStateClosed { - err := s.call.PeerConnection.RemoveTrack(s.sender) - if err != nil { - s.logger.WithError(err).Error("failed to remove track") - } - } - - s.call.RemoveSubscriber(s) - - s.mutex.Lock() - s.publisher = nil - s.mutex.Unlock() - - s.logger.Info("unsubscribed") -} - -func (s *Subscriber) forwardRTCP() { - for { - // If we unsubscribed, stop forwarding RTCP packets - if s.publisher == nil { - return - } - - packets, _, err := s.sender.ReadRTCP() - if err != nil { - if errors.Is(err, io.ErrClosedPipe) || errors.Is(err, io.EOF) { - return - } - - s.logger.WithError(err).Warn("failed to read RTCP on track") - } - - s.publisher.WriteRTCP(packets) - } -}