diff --git a/.golangci.yaml b/.golangci.yaml index a27073d..1ae85b5 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -29,4 +29,6 @@ linters: - gomnd # we use status code numbers and for our use case it's not practical - godox # we have TODOs at this stage of the project, enable in future - forbidigo # we use things like fmt.Printf for debugging, enable in future + - wsl # somehow this plugin causes more harm than use as it enables lots of things to be configured without causing spaghetti-code (grouping similar things together) + - nlreturn # not always practical, it was disabled before strict lints were introduced, then added, now it's clear why it was disabled at the first place :) fast: true diff --git a/Dockerfile b/Dockerfile index b3bc9f6..b58d0f2 100644 --- a/Dockerfile +++ b/Dockerfile @@ -13,9 +13,9 @@ COPY go.sum ./ # source code do not invalidate our downloaded layer. RUN go mod download -COPY ./src ./src +COPY ./pkg ./pkg -RUN go build -o /waterfall ./src +RUN go build -o /waterfall ./pkg ## diff --git a/config.yaml.sample b/config.yaml.sample index 0b1b4df..e53b2e1 100644 --- a/config.yaml.sample +++ b/config.yaml.sample @@ -1,4 +1,6 @@ -homeserverurl: "http://localhost:8008" -userid: "@sfu:shadowfax" -accesstoken: "..." -timeout: 30 +matrix: + homeserverurl: "http://localhost:8008" + userid: "@sfu:shadowfax" + accesstoken: "..." +conference: + timeout: 30 diff --git a/docker-compose.yaml b/docker-compose.yaml index 90ffc6d..83d75e0 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -9,7 +9,9 @@ services: environment: # Set the `CONFIG` to the configuration you want. CONFIG: | - homeserverurl: "http://localhost:8008" - userid: "@sfu:shadowfax" - accesstoken: "..." - timeout: 30 + matrix: + homeserverurl: "http://localhost:8008" + userid: "@sfu:shadowfax" + accesstoken: "..." + conference: + timeout: 30 diff --git a/go.mod b/go.mod index 67dcfa9..adaf0b3 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require github.com/pion/webrtc/v3 v3.1.31 require ( github.com/pion/rtcp v1.2.9 github.com/sirupsen/logrus v1.9.0 + golang.org/x/exp v0.0.0-20221114191408-850992195362 gopkg.in/yaml.v3 v3.0.1 maunium.net/go/mautrix v0.11.0 ) @@ -34,7 +35,7 @@ require ( github.com/tidwall/sjson v1.2.4 // indirect golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d // indirect golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e // indirect - golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 // indirect + golang.org/x/sys v0.1.0 // indirect golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect ) diff --git a/go.sum b/go.sum index 17aeab1..e02475c 100644 --- a/go.sum +++ b/go.sum @@ -103,6 +103,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20220131195533-30dcbda58838/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d h1:sK3txAijHtOK88l68nt020reeT1ZdKLIYetKl95FzVY= golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/exp v0.0.0-20221114191408-850992195362 h1:NoHlPRbyl1VFI6FjwHtPQCN7wAMXI6cKcqrmXhOOfBQ= +golang.org/x/exp v0.0.0-20221114191408-850992195362/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= @@ -134,8 +136,9 @@ golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 h1:0A+M6Uqn+Eje4kHMK80dtF3JCXC4ykBgQG4Fe06QRhQ= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.1.0 h1:kunALQeHf1/185U1i0GOB/fy1IPRDDpuoOOqRReG57U= +golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/pkg/common/channel.go b/pkg/common/channel.go new file mode 100644 index 0000000..5c702f1 --- /dev/null +++ b/pkg/common/channel.go @@ -0,0 +1,70 @@ +package common + +import "sync/atomic" + +// In Go, unbounded channel means something different than what it means in Rust. +// I.e. unlike Rust, "unbounded" in Go means that the channel has **no buffer**, +// meaning that each attempt to send will block the channel until the receiver +// reads it. Majority of primitives here in `waterfall` are designed under assumption +// that sending is not blocking. +const UnboundedChannelSize = 128 + +// Creates a new channel, returns two counterparts of it where one can only send and another can only receive. +// Unlike traditional Go channels, these allow the receiver to mark the channel as closed which would then fail +// to send any messages to the channel over `Send“. +func NewChannel[M any]() (Sender[M], Receiver[M]) { + channel := make(chan M, UnboundedChannelSize) + closed := &atomic.Bool{} + sender := Sender[M]{channel, closed} + receiver := Receiver[M]{channel, closed} + return sender, receiver +} + +// Sender counterpart of the channel. +type Sender[M any] struct { + // The channel itself. + channel chan<- M + // Atomic variable that indicates whether the channel is closed. + receiverClosed *atomic.Bool +} + +// Tries to send a message if the channel is not closed. +// Returns the message back if the channel is closed. +func (s *Sender[M]) Send(message M) *M { + if !s.receiverClosed.Load() { + s.channel <- message + return nil + } else { + return &message + } +} + +// The receiver counterpart of the channel. +type Receiver[M any] struct { + // The channel itself. It's public, so that we can combine it in `select` statements. + Channel <-chan M + // Atomic variable that indicates whether the channel is closed. + receiverClosed *atomic.Bool +} + +// Marks the channel as closed, which means that no messages could be sent via this channel. +// Any attempt to send a message would result in an error. This is similar to closing the +// channel except that we don't close the underlying channel (since in Go receivers can't +// close the channel). +// +// This function reads (in a non-blocking way) all pending messages until blocking. Otherwise, +// they will stay forver in a channel and get lost. +func (r *Receiver[M]) Close() []M { + r.receiverClosed.Store(true) + + messages := make([]M, 0) + for { + msg, ok := <-r.Channel + if !ok { + break + } + messages = append(messages, msg) + } + + return messages +} diff --git a/pkg/common/message_sink.go b/pkg/common/message_sink.go new file mode 100644 index 0000000..1c7604f --- /dev/null +++ b/pkg/common/message_sink.go @@ -0,0 +1,61 @@ +package common + +import ( + "errors" + "sync/atomic" +) + +// MessageSink is a helper struct that allows to send messages to a message sink. +// The MessageSink abstracts the message sink which has a certain sender, so that +// the sender does not have to be specified every time a message is sent. +// At the same it guarantees that the caller can't alter the `sender`, which means that +// the sender can't impersonate another sender (and we guarantee this on a compile-time). +type MessageSink[SenderType comparable, MessageType any] struct { + // The sender of the messages. This is useful for multiple-producer-single-consumer scenarios. + sender SenderType + // The message sink to which the messages are sent. + messageSink chan<- Message[SenderType, MessageType] + // Atomic variable that indicates whether the message sink is sealed. + // This is used to prevent sending messages to a sealed message sink. + // The variable is atomic because it may be accessed from multiple goroutines. + sealed atomic.Bool +} + +// Creates a new MessageSink. The function is generic allowing us to use it for multiple use cases. +func NewMessageSink[S comparable, M any](sender S, messageSink chan<- Message[S, M]) *MessageSink[S, M] { + return &MessageSink[S, M]{ + sender: sender, + messageSink: messageSink, + } +} + +// Sends a message to the message sink. +func (s *MessageSink[S, M]) Send(message M) error { + if s.sealed.Load() { + return errors.New("The channel is sealed, you can't send any messages over it") + } + + s.messageSink <- Message[S, M]{ + Sender: s.sender, + Content: message, + } + + return nil +} + +// Seals the channel, which means that no messages could be sent via this channel. +// Any attempt to send a message would result in an error. This is similar to closing the +// channel except that we don't close the underlying channel (since there might be other +// senders that may want to use it). +func (s *MessageSink[S, M]) Seal() { + s.sealed.Store(true) +} + +// Messages that are sent from the peer to the conference in order to communicate with other peers. +// Since each peer is isolated from others, it can't influence the state of other peers directly. +type Message[SenderType comparable, MessageType any] struct { + // The sender of the message. + Sender SenderType + // The content of the message. + Content MessageType +} diff --git a/pkg/conference/config.go b/pkg/conference/config.go new file mode 100644 index 0000000..a239d60 --- /dev/null +++ b/pkg/conference/config.go @@ -0,0 +1,8 @@ +package conference + +// Configuration for the group conferences (calls). +type Config struct { + // Keep-alive timeout for WebRTC connections. If no keep-alive has been received + // from the client for this duration, the connection is considered dead (in seconds). + KeepAliveTimeout int `yaml:"timeout"` +} diff --git a/pkg/conference/data_channel_message_processor.go b/pkg/conference/data_channel_message_processor.go new file mode 100644 index 0000000..4edf49e --- /dev/null +++ b/pkg/conference/data_channel_message_processor.go @@ -0,0 +1,78 @@ +package conference + +import ( + "github.com/pion/webrtc/v3" + "golang.org/x/exp/slices" + "maunium.net/go/mautrix/event" +) + +// Handle the `SFUMessage` event from the DataChannel message. +func (c *Conference) processSelectDCMessage(participant *Participant, msg event.SFUMessage) { + participant.logger.Info("Received select request over DC") + + // Find tracks based on what we were asked for. + tracks := c.getTracks(msg.Start) + + // Let's check if we have all the tracks that we were asked for are there. + // If not, we will list which are not available (later on we must inform participant + // about it unless the participant retries it). + if len(tracks) != len(msg.Start) { + for _, expected := range msg.Start { + found := slices.IndexFunc(tracks, func(track *webrtc.TrackLocalStaticRTP) bool { + return track.StreamID() == expected.StreamID && track.ID() == expected.TrackID + }) + + if found == -1 { + c.logger.Warnf("Track not found: %s", expected.TrackID) + } + } + } + + // Subscribe to the found tracks. + for _, track := range tracks { + if err := participant.peer.SubscribeTo(track); err != nil { + participant.logger.Errorf("Failed to subscribe to track: %v", err) + return + } + } +} + +func (c *Conference) processAnswerDCMessage(participant *Participant, msg event.SFUMessage) { + participant.logger.Info("Received SDP answer over DC") + + if err := participant.peer.ProcessSDPAnswer(msg.SDP); err != nil { + participant.logger.Errorf("Failed to set SDP answer: %v", err) + return + } +} + +func (c *Conference) processPublishDCMessage(participant *Participant, msg event.SFUMessage) { + participant.logger.Info("Received SDP offer over DC") + + answer, err := participant.peer.ProcessSDPOffer(msg.SDP) + if err != nil { + participant.logger.Errorf("Failed to set SDP offer: %v", err) + return + } + + participant.streamMetadata = msg.Metadata + + participant.sendDataChannelMessage(event.SFUMessage{ + Op: event.SFUOperationAnswer, + SDP: answer.SDP, + Metadata: c.getAvailableStreamsFor(participant.id), + }) +} + +func (c *Conference) processUnpublishDCMessage(participant *Participant) { + participant.logger.Info("Received unpublish over DC") +} + +func (c *Conference) processAliveDCMessage(participant *Participant) { + participant.peer.ProcessHeartbeat() +} + +func (c *Conference) processMetadataDCMessage(participant *Participant, msg event.SFUMessage) { + participant.streamMetadata = msg.Metadata + c.resendMetadataToAllExcept(participant.id) +} diff --git a/pkg/conference/matrix_message_processor.go b/pkg/conference/matrix_message_processor.go new file mode 100644 index 0000000..691a1f5 --- /dev/null +++ b/pkg/conference/matrix_message_processor.go @@ -0,0 +1,127 @@ +package conference + +import ( + "time" + + "github.com/matrix-org/waterfall/pkg/common" + "github.com/matrix-org/waterfall/pkg/peer" + "github.com/pion/webrtc/v3" + "github.com/sirupsen/logrus" + "maunium.net/go/mautrix/event" +) + +type MessageContent interface{} + +type MatrixMessage struct { + Sender ParticipantID + Content MessageContent + RawEvent *event.Event +} + +// New participant tries to join the conference. +func (c *Conference) onNewParticipant(participantID ParticipantID, inviteEvent *event.CallInviteEventContent) error { + logger := c.logger.WithFields(logrus.Fields{ + "user_id": participantID.UserID, + "device_id": participantID.DeviceID, + }) + + logger.Info("Incoming call invite") + + // As per MSC3401, when the `session_id` field changes from an incoming `m.call.member` event, + // any existing calls from this device in this call should be terminated. + if participant := c.participants[participantID]; participant != nil { + if participant.remoteSessionID == inviteEvent.SenderSessionID { + c.logger.Errorf("Found existing participant with equal DeviceID and SessionID") + } else { + c.removeParticipant(participantID) + } + } + + participant := c.participants[participantID] + var sdpAnswer *webrtc.SessionDescription + + // If participant exists still exists, then it means that the client does not behave properly. + // In this case we treat this new invitation as a new SDP offer. Otherwise, we create a new one. + if participant != nil { + answer, err := participant.peer.ProcessSDPOffer(inviteEvent.Offer.SDP) + if err != nil { + logger.WithError(err).Errorf("Failed to process SDP offer") + return err + } + sdpAnswer = answer + } else { + messageSink := common.NewMessageSink(participantID, c.peerMessages) + + keepAliveDeadline := time.Duration(c.config.KeepAliveTimeout) * time.Second + peer, answer, err := peer.NewPeer(inviteEvent.Offer.SDP, messageSink, logger, keepAliveDeadline) + if err != nil { + logger.WithError(err).Errorf("Failed to process SDP offer") + return err + } + + participant = &Participant{ + id: participantID, + peer: peer, + logger: logger, + remoteSessionID: inviteEvent.SenderSessionID, + streamMetadata: inviteEvent.SDPStreamMetadata, + publishedTracks: make(map[event.SFUTrackDescription]PublishedTrack), + } + + c.participants[participantID] = participant + sdpAnswer = answer + } + + // Send the answer back to the remote peer. + recipient := participant.asMatrixRecipient() + streamMetadata := c.getAvailableStreamsFor(participantID) + participant.logger.WithField("sdpAnswer", sdpAnswer.SDP).Debug("Sending SDP answer") + c.signaling.SendSDPAnswer(recipient, streamMetadata, sdpAnswer.SDP) + return nil +} + +// Process new ICE candidates received from Matrix signaling (from the remote peer) and forward them to +// our internal peer connection. +func (c *Conference) onCandidates(participantID ParticipantID, ev *event.CallCandidatesEventContent) { + if participant := c.getParticipant(participantID, nil); participant != nil { + participant.logger.Info("Received remote ICE candidates") + + // Convert the candidates to the WebRTC format. + candidates := make([]webrtc.ICECandidateInit, len(ev.Candidates)) + for i, candidate := range ev.Candidates { + SDPMLineIndex := uint16(candidate.SDPMLineIndex) + candidates[i] = webrtc.ICECandidateInit{ + Candidate: candidate.Candidate, + SDPMid: &candidate.SDPMID, + SDPMLineIndex: &SDPMLineIndex, + UsernameFragment: new(string), + } + } + + participant.peer.ProcessNewRemoteCandidates(candidates) + } +} + +// Process an acknowledgement from the remote peer that the SDP answer has been received +// and that the call can now proceed. +func (c *Conference) onSelectAnswer(participantID ParticipantID, ev *event.CallSelectAnswerEventContent) { + if participant := c.getParticipant(participantID, nil); participant != nil { + participant.logger.Info("Received remote answer selection") + + if ev.SelectedPartyID != string(c.signaling.DeviceID()) { + c.logger.WithFields(logrus.Fields{ + "device_id": ev.SelectedPartyID, + "user_id": participantID, + }).Errorf("Call was answered on a different device, kicking this peer") + c.removeParticipant(participantID) + } + } +} + +// Process a message from the remote peer telling that it wants to hang up the call. +func (c *Conference) onHangup(participantID ParticipantID, ev *event.CallHangupEventContent) { + if participant := c.participants[participantID]; participant != nil { + participant.logger.Info("Received remote hangup") + c.removeParticipant(participantID) + } +} diff --git a/pkg/conference/messsage_processor.go b/pkg/conference/messsage_processor.go new file mode 100644 index 0000000..3865aac --- /dev/null +++ b/pkg/conference/messsage_processor.go @@ -0,0 +1,85 @@ +package conference + +import ( + "errors" + + "github.com/matrix-org/waterfall/pkg/common" + "github.com/matrix-org/waterfall/pkg/peer" + "maunium.net/go/mautrix/event" +) + +// Listen on messages from incoming channels and process them. +// This is essentially the main loop of the conference. +// If this function returns, the conference is over. +func (c *Conference) processMessages() { + for { + select { + case msg := <-c.peerMessages: + c.processPeerMessage(msg) + case msg := <-c.matrixMessages.Channel: + c.processMatrixMessage(msg) + } + + // If there are no more participants, stop the conference. + if len(c.participants) == 0 { + c.logger.Info("No more participants, stopping the conference") + // Close the channel so that the sender can't push any messages. + unreadMessages := c.matrixMessages.Close() + + // Send the information that we ended to the owner and pass the message + // that we did not process (so that we don't drop it silently). + c.endNotifier.Notify(unreadMessages) + return + } + } +} + +// Process a message from a local peer. +func (c *Conference) processPeerMessage(message common.Message[ParticipantID, peer.MessageContent]) { + participant := c.getParticipant(message.Sender, errors.New("received a message from a deleted participant")) + if participant == nil { + return + } + + // Since Go does not support ADTs, we have to use a switch statement to + // determine the actual type of the message. + switch msg := message.Content.(type) { + case peer.JoinedTheCall: + c.processJoinedTheCallMessage(participant, msg) + case peer.LeftTheCall: + c.processLeftTheCallMessage(participant, msg) + case peer.NewTrackPublished: + c.processNewTrackPublishedMessage(participant, msg) + case peer.PublishedTrackFailed: + c.processPublishedTrackFailedMessage(participant, msg) + case peer.NewICECandidate: + c.processNewICECandidateMessage(participant, msg) + case peer.ICEGatheringComplete: + c.processICEGatheringCompleteMessage(participant, msg) + case peer.RenegotiationRequired: + c.processRenegotiationRequiredMessage(participant, msg) + case peer.DataChannelMessage: + c.processDataChannelMessage(participant, msg) + case peer.DataChannelAvailable: + c.processDataChannelAvailableMessage(participant, msg) + case peer.RTCPReceived: + c.processForwardRTCPMessage(msg) + default: + c.logger.Errorf("Unknown message type: %T", msg) + } +} + +func (c *Conference) processMatrixMessage(msg MatrixMessage) { + switch ev := msg.Content.(type) { + case *event.CallInviteEventContent: + c.onNewParticipant(msg.Sender, ev) + case *event.CallCandidatesEventContent: + c.onCandidates(msg.Sender, ev) + case *event.CallSelectAnswerEventContent: + c.onSelectAnswer(msg.Sender, ev) + case *event.CallHangupEventContent: + c.onHangup(msg.Sender, ev) + default: + c.logger.Errorf("Unexpected event type: %T", ev) + } +} diff --git a/pkg/conference/participant.go b/pkg/conference/participant.go new file mode 100644 index 0000000..59c2c9d --- /dev/null +++ b/pkg/conference/participant.go @@ -0,0 +1,59 @@ +package conference + +import ( + "encoding/json" + "time" + + "github.com/matrix-org/waterfall/pkg/peer" + "github.com/matrix-org/waterfall/pkg/signaling" + "github.com/pion/webrtc/v3" + "github.com/sirupsen/logrus" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +// Things that we assume as identifiers for the participants in the call. +// There could be no 2 participants in the room with identical IDs. +type ParticipantID struct { + UserID id.UserID + DeviceID id.DeviceID + CallID string +} + +type PublishedTrack struct { + track *webrtc.TrackLocalStaticRTP + // The time when we sent the last PLI to the sender. We store this to avoid + // spamming the sender. + lastPLITimestamp time.Time +} + +// Participant represents a participant in the conference. +type Participant struct { + id ParticipantID + logger *logrus.Entry + peer *peer.Peer[ParticipantID] + remoteSessionID id.SessionID + streamMetadata event.CallSDPStreamMetadata + publishedTracks map[event.SFUTrackDescription]PublishedTrack +} + +func (p *Participant) asMatrixRecipient() signaling.MatrixRecipient { + return signaling.MatrixRecipient{ + UserID: p.id.UserID, + DeviceID: p.id.DeviceID, + CallID: p.id.CallID, + RemoteSessionID: p.remoteSessionID, + } +} + +func (p *Participant) sendDataChannelMessage(toSend event.SFUMessage) { + jsonToSend, err := json.Marshal(toSend) + if err != nil { + p.logger.Error("Failed to marshal data channel message") + } + + if err := p.peer.SendOverDataChannel(string(jsonToSend)); err != nil { + // TODO: We must buffer the message in this case and re-send it once the data channel is recovered! + p.logger.Error("Failed to send data channel message") + } +} diff --git a/pkg/conference/peer_message_processor.go b/pkg/conference/peer_message_processor.go new file mode 100644 index 0000000..a815972 --- /dev/null +++ b/pkg/conference/peer_message_processor.go @@ -0,0 +1,128 @@ +package conference + +import ( + "encoding/json" + "time" + + "github.com/matrix-org/waterfall/pkg/peer" + "github.com/pion/webrtc/v3" + "maunium.net/go/mautrix/event" +) + +func (c *Conference) processJoinedTheCallMessage(participant *Participant, message peer.JoinedTheCall) { + participant.logger.Info("Joined the call") +} + +func (c *Conference) processLeftTheCallMessage(participant *Participant, msg peer.LeftTheCall) { + participant.logger.Info("Left the call: %s", msg.Reason) + c.removeParticipant(participant.id) + c.signaling.SendHangup(participant.asMatrixRecipient(), msg.Reason) +} + +func (c *Conference) processNewTrackPublishedMessage(participant *Participant, msg peer.NewTrackPublished) { + participant.logger.Infof("Published new track: %s", msg.Track.ID()) + key := event.SFUTrackDescription{ + StreamID: msg.Track.StreamID(), + TrackID: msg.Track.ID(), + } + + if _, ok := participant.publishedTracks[key]; ok { + c.logger.Errorf("Track already published: %v", key) + return + } + + participant.publishedTracks[key] = PublishedTrack{track: msg.Track} + c.resendMetadataToAllExcept(participant.id) +} + +func (c *Conference) processPublishedTrackFailedMessage(participant *Participant, msg peer.PublishedTrackFailed) { + participant.logger.Infof("Failed published track: %s", msg.Track.ID()) + delete(participant.publishedTracks, event.SFUTrackDescription{ + StreamID: msg.Track.StreamID(), + TrackID: msg.Track.ID(), + }) + + for _, otherParticipant := range c.participants { + if otherParticipant.id == participant.id { + continue + } + + otherParticipant.peer.UnsubscribeFrom([]*webrtc.TrackLocalStaticRTP{msg.Track}) + } + + c.resendMetadataToAllExcept(participant.id) +} + +func (c *Conference) processNewICECandidateMessage(participant *Participant, msg peer.NewICECandidate) { + participant.logger.Debug("Received a new local ICE candidate") + + // Convert WebRTC ICE candidate to Matrix ICE candidate. + jsonCandidate := msg.Candidate.ToJSON() + candidates := []event.CallCandidate{{ + Candidate: jsonCandidate.Candidate, + SDPMLineIndex: int(*jsonCandidate.SDPMLineIndex), + SDPMID: *jsonCandidate.SDPMid, + }} + c.signaling.SendICECandidates(participant.asMatrixRecipient(), candidates) +} + +func (c *Conference) processICEGatheringCompleteMessage(participant *Participant, msg peer.ICEGatheringComplete) { + participant.logger.Info("Completed local ICE gathering") + + // Send an empty array of candidates to indicate that ICE gathering is complete. + c.signaling.SendCandidatesGatheringFinished(participant.asMatrixRecipient()) +} + +func (c *Conference) processRenegotiationRequiredMessage(participant *Participant, msg peer.RenegotiationRequired) { + participant.logger.Info("Started renegotiation") + participant.sendDataChannelMessage(event.SFUMessage{ + Op: event.SFUOperationOffer, + SDP: msg.Offer.SDP, + Metadata: c.getAvailableStreamsFor(participant.id), + }) +} + +func (c *Conference) processDataChannelMessage(participant *Participant, msg peer.DataChannelMessage) { + participant.logger.Debug("Received data channel message") + var sfuMessage event.SFUMessage + if err := json.Unmarshal([]byte(msg.Message), &sfuMessage); err != nil { + c.logger.Errorf("Failed to unmarshal SFU message: %v", err) + return + } + + switch sfuMessage.Op { + case event.SFUOperationSelect: + c.processSelectDCMessage(participant, sfuMessage) + case event.SFUOperationAnswer: + c.processAnswerDCMessage(participant, sfuMessage) + case event.SFUOperationPublish: + c.processPublishDCMessage(participant, sfuMessage) + case event.SFUOperationUnpublish: + c.processUnpublishDCMessage(participant) + case event.SFUOperationAlive: + c.processAliveDCMessage(participant) + case event.SFUOperationMetadata: + c.processMetadataDCMessage(participant, sfuMessage) + } +} + +func (c *Conference) processDataChannelAvailableMessage(participant *Participant, msg peer.DataChannelAvailable) { + participant.logger.Info("Connected data channel") + participant.sendDataChannelMessage(event.SFUMessage{ + Op: event.SFUOperationMetadata, + Metadata: c.getAvailableStreamsFor(participant.id), + }) +} + +func (c *Conference) processForwardRTCPMessage(msg peer.RTCPReceived) { + for _, participant := range c.participants { + for _, publishedTrack := range participant.publishedTracks { + if publishedTrack.track.StreamID() == msg.StreamID && publishedTrack.track.ID() == msg.TrackID { + err := participant.peer.WriteRTCP(msg.Packets, msg.StreamID, msg.TrackID, publishedTrack.lastPLITimestamp) + if err == nil { + publishedTrack.lastPLITimestamp = time.Now() + } + } + } + } +} diff --git a/pkg/conference/start.go b/pkg/conference/start.go new file mode 100644 index 0000000..91eadac --- /dev/null +++ b/pkg/conference/start.go @@ -0,0 +1,64 @@ +/* +Copyright 2022 The Matrix.org Foundation C.I.C. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package conference + +import ( + "github.com/matrix-org/waterfall/pkg/common" + "github.com/matrix-org/waterfall/pkg/peer" + "github.com/matrix-org/waterfall/pkg/signaling" + "github.com/sirupsen/logrus" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +// Starts a new conference or fails and returns an error. +func StartConference( + confID string, + config Config, + signaling signaling.MatrixSignaling, + conferenceEndNotifier ConferenceEndNotifier, + userID id.UserID, + inviteEvent *event.CallInviteEventContent, +) (*common.Sender[MatrixMessage], error) { + sender, receiver := common.NewChannel[MatrixMessage]() + + conference := &Conference{ + id: confID, + config: config, + signaling: signaling, + matrixMessages: receiver, + endNotifier: conferenceEndNotifier, + participants: make(map[ParticipantID]*Participant), + peerMessages: make(chan common.Message[ParticipantID, peer.MessageContent], common.UnboundedChannelSize), + logger: logrus.WithFields(logrus.Fields{"conf_id": confID}), + } + + participantID := ParticipantID{UserID: userID, DeviceID: inviteEvent.DeviceID, CallID: inviteEvent.CallID} + if err := conference.onNewParticipant(participantID, inviteEvent); err != nil { + return nil, err + } + + // Start conference "main loop". + go conference.processMessages() + + return &sender, nil +} + +type ConferenceEndNotifier interface { + // Called when the conference ends. + Notify(unread []MatrixMessage) +} diff --git a/pkg/conference/state.go b/pkg/conference/state.go new file mode 100644 index 0000000..de43c00 --- /dev/null +++ b/pkg/conference/state.go @@ -0,0 +1,120 @@ +package conference + +import ( + "github.com/matrix-org/waterfall/pkg/common" + "github.com/matrix-org/waterfall/pkg/peer" + "github.com/matrix-org/waterfall/pkg/signaling" + "github.com/pion/webrtc/v3" + "github.com/sirupsen/logrus" + "maunium.net/go/mautrix/event" +) + +// A single conference. Call and conference mean the same in context of Matrix. +type Conference struct { + id string + config Config + logger *logrus.Entry + endNotifier ConferenceEndNotifier + + signaling signaling.MatrixSignaling + participants map[ParticipantID]*Participant + + peerMessages chan common.Message[ParticipantID, peer.MessageContent] + matrixMessages common.Receiver[MatrixMessage] +} + +func (c *Conference) getParticipant(participantID ParticipantID, optionalErrorMessage error) *Participant { + participant, ok := c.participants[participantID] + if !ok { + if optionalErrorMessage != nil { + c.logger.WithError(optionalErrorMessage) + } else { + c.logger.Error("Participant not found") + } + + return nil + } + + return participant +} + +// Helper to terminate and remove a participant from the conference. +func (c *Conference) removeParticipant(participantID ParticipantID) { + participant := c.getParticipant(participantID, nil) + if participant == nil { + return + } + + // Terminate the participant and remove it from the list. + participant.peer.Terminate() + delete(c.participants, participantID) + + // Inform the other participants about updated metadata (since the participant left + // the corresponding streams of the participant are no longer available, so we're informing + // others about it). + c.resendMetadataToAllExcept(participantID) + + // Remove the participant's tracks from all participants who might have subscribed to them. + obsoleteTracks := []*webrtc.TrackLocalStaticRTP{} + for _, publishedTrack := range participant.publishedTracks { + obsoleteTracks = append(obsoleteTracks, publishedTrack.track) + } + for _, otherParticipant := range c.participants { + otherParticipant.peer.UnsubscribeFrom(obsoleteTracks) + } +} + +// Helper to get the list of available streams for a given participant, i.e. the list of streams +// that a given participant **can subscribe to**. Each stream may have multiple tracks. +func (c *Conference) getAvailableStreamsFor(forParticipant ParticipantID) event.CallSDPStreamMetadata { + streamsMetadata := make(event.CallSDPStreamMetadata) + for id, participant := range c.participants { + // Skip us. As we know about our own tracks. + if forParticipant != id { + // Now, find out which of published tracks belong to the streams for which we have metadata + // available and construct a metadata map for a given participant based on that. + for _, track := range participant.publishedTracks { + trackID, streamID := track.track.ID(), track.track.StreamID() + + if metadata, ok := streamsMetadata[streamID]; ok { + metadata.Tracks[trackID] = event.CallSDPStreamMetadataTrack{} + streamsMetadata[streamID] = metadata + } else if metadata, ok := participant.streamMetadata[streamID]; ok { + metadata.Tracks = event.CallSDPStreamMetadataTracks{trackID: event.CallSDPStreamMetadataTrack{}} + streamsMetadata[streamID] = metadata + } else { + participant.logger.Warnf("Don't have metadata for stream %s", streamID) + } + } + } + } + + return streamsMetadata +} + +// Helper that returns the list of streams inside this conference that match the given stream IDs and track IDs. +func (c *Conference) getTracks(identifiers []event.SFUTrackDescription) []*webrtc.TrackLocalStaticRTP { + tracks := make([]*webrtc.TrackLocalStaticRTP, 0) + for _, participant := range c.participants { + // Check if this participant has any of the tracks that we're looking for. + for _, identifier := range identifiers { + if track, ok := participant.publishedTracks[identifier]; ok { + tracks = append(tracks, track.track) + } + } + } + + return tracks +} + +// Helper that sends current metadata about all available tracks to all participants except a given one. +func (c *Conference) resendMetadataToAllExcept(exceptMe ParticipantID) { + for participantID, participant := range c.participants { + if participantID != exceptMe { + participant.sendDataChannelMessage(event.SFUMessage{ + Op: event.SFUOperationMetadata, + Metadata: c.getAvailableStreamsFor(participantID), + }) + } + } +} diff --git a/src/config.go b/pkg/config/config.go similarity index 59% rename from src/config.go rename to pkg/config/config.go index c76e093..f9fde9e 100644 --- a/src/config.go +++ b/pkg/config/config.go @@ -1,40 +1,38 @@ -package main +package config import ( "errors" "fmt" "os" + "github.com/matrix-org/waterfall/pkg/conference" + "github.com/matrix-org/waterfall/pkg/signaling" "github.com/sirupsen/logrus" "gopkg.in/yaml.v3" - "maunium.net/go/mautrix/id" ) -// The mandatory SFU configuration. +// SFU configuration. type Config struct { - // The Matrix ID (MXID) of the SFU. - UserID id.UserID - // The ULR of the homeserver that SFU talks to. - HomeserverURL string - // The access token for the Matrix SDK. - AccessToken string - // Keep-alive timeout for WebRTC connections. If no keep-alive has been received - // from the client for this duration, the connection is considered dead. - KeepAliveTimeout int + // Matrix configuration. + Matrix signaling.Config `yaml:"matrix"` + // Conference (call) configuration. + Conference conference.Config `yaml:"conference"` + // Starting from which level to log stuff. + LogLevel string `yaml:"log"` } // Tries to load a config from the `CONFIG` environment variable. // If the environment variable is not set, tries to load a config from the // provided path to the config file (YAML). Returns an error if the config could // not be loaded. -func loadConfig(path string) (*Config, error) { - config, err := loadConfigFromEnv() +func LoadConfig(path string) (*Config, error) { + config, err := LoadConfigFromEnv() if err != nil { if !errors.Is(err, ErrNoConfigEnvVar) { return nil, err } - return loadConfigFromPath(path) + return LoadConfigFromPath(path) } return config, nil @@ -45,17 +43,17 @@ var ErrNoConfigEnvVar = errors.New("environment variable not set or invalid") // Tries to load the config from environment variable (`CONFIG`). // Returns an error if not all environment variables are set. -func loadConfigFromEnv() (*Config, error) { +func LoadConfigFromEnv() (*Config, error) { configEnv := os.Getenv("CONFIG") if configEnv == "" { return nil, ErrNoConfigEnvVar } - return loadConfigFromString(configEnv) + return LoadConfigFromString(configEnv) } // Tries to load a config from the provided path. -func loadConfigFromPath(path string) (*Config, error) { +func LoadConfigFromPath(path string) (*Config, error) { logrus.WithField("path", path).Info("loading config") file, err := os.ReadFile(path) @@ -63,12 +61,12 @@ func loadConfigFromPath(path string) (*Config, error) { return nil, fmt.Errorf("failed to read file: %w", err) } - return loadConfigFromString(string(file)) + return LoadConfigFromString(string(file)) } // Load config from the provided string. // Returns an error if the string is not a valid YAML. -func loadConfigFromString(configString string) (*Config, error) { +func LoadConfigFromString(configString string) (*Config, error) { logrus.Info("loading config from string") var config Config @@ -76,5 +74,12 @@ func loadConfigFromString(configString string) (*Config, error) { return nil, fmt.Errorf("failed to unmarshal YAML file: %w", err) } + if config.Matrix.UserID == "" || + config.Matrix.HomeserverURL == "" || + config.Matrix.AccessToken == "" || + config.Conference.KeepAliveTimeout == 0 { + return nil, errors.New("invalid config values") + } + return &config, nil } diff --git a/pkg/main.go b/pkg/main.go new file mode 100644 index 0000000..bed8313 --- /dev/null +++ b/pkg/main.go @@ -0,0 +1,98 @@ +/* +Copyright 2022 The Matrix.org Foundation C.I.C. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package main + +import ( + "flag" + "os" + "os/signal" + "syscall" + + "github.com/matrix-org/waterfall/pkg/config" + "github.com/matrix-org/waterfall/pkg/signaling" + "github.com/sirupsen/logrus" + "maunium.net/go/mautrix/event" +) + +func main() { + // Parse command line flags. + var ( + configFilePath = flag.String("config", "config.yaml", "configuration file path") + cpuProfile = flag.String("cpuProfile", "", "write CPU profile to `file`") + memProfile = flag.String("memProfile", "", "write memory profile to `file`") + ) + flag.Parse() + + // Initialize logging subsystem (formatting, global logging framework etc). + logrus.SetFormatter(&logrus.TextFormatter{FullTimestamp: true, ForceColors: true}) + + // Define functions that are called before exiting. + // This is useful to stop the profiler if it's enabled. + deferred_functions := []func(){} + if *cpuProfile != "" { + deferred_functions = append(deferred_functions, InitCPUProfiling(cpuProfile)) + } + if *memProfile != "" { + deferred_functions = append(deferred_functions, InitMemoryProfiling(memProfile)) + } + + // Handle signal interruptions. + c := make(chan os.Signal, 2) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + go func() { + <-c + for _, function := range deferred_functions { + function() + } + os.Exit(0) + }() + + // Load the config file from the environment variable or path. + config, err := config.LoadConfig(*configFilePath) + if err != nil { + logrus.WithError(err).Fatal("could not load config") + return + } + + switch config.LogLevel { + case "debug": + logrus.SetLevel(logrus.DebugLevel) + case "info": + logrus.SetLevel(logrus.InfoLevel) + case "warn": + logrus.SetLevel(logrus.WarnLevel) + case "error": + logrus.SetLevel(logrus.ErrorLevel) + case "fatal": + logrus.SetLevel(logrus.FatalLevel) + case "panic": + logrus.SetLevel(logrus.PanicLevel) + default: + logrus.SetLevel(logrus.InfoLevel) + } + + // Create matrix client. + matrixClient := signaling.NewMatrixClient(config.Matrix) + + // Create a router to route incoming To-Device messages to the right conference. + routerChannel := newRouter(matrixClient, config.Conference) + + // Start matrix client sync. This function will block until the sync fails. + matrixClient.RunSyncing(func(e *event.Event) { + routerChannel <- e + }) +} diff --git a/pkg/peer/keepalive.go b/pkg/peer/keepalive.go new file mode 100644 index 0000000..b66f2ed --- /dev/null +++ b/pkg/peer/keepalive.go @@ -0,0 +1,23 @@ +package peer + +import "time" + +type HeartBeat struct{} + +// Starts a goroutine that will execute `onDeadLine` closure in case nothing has been published +// to the `heartBeat` channel for `deadline` duration. The goroutine stops once the channel is closed. +func startKeepAlive(deadline time.Duration, heartBeat <-chan HeartBeat, onDeadLine func()) { + go func() { + for { + select { + case <-time.After(deadline): + onDeadLine() + return + case _, ok := <-heartBeat: + if !ok { + return + } + } + } + }() +} diff --git a/pkg/peer/messages.go b/pkg/peer/messages.go new file mode 100644 index 0000000..0593d72 --- /dev/null +++ b/pkg/peer/messages.go @@ -0,0 +1,47 @@ +package peer + +import ( + "github.com/pion/rtcp" + "github.com/pion/webrtc/v3" + "maunium.net/go/mautrix/event" +) + +// Due to the limitation of Go, we're using the `interface{}` to be able to use switch the actual +// type of the message on runtime. The underlying types do not necessary need to be structures. +type MessageContent = interface{} + +type JoinedTheCall struct{} + +type LeftTheCall struct { + Reason event.CallHangupReason +} + +type NewTrackPublished struct { + Track *webrtc.TrackLocalStaticRTP +} + +type PublishedTrackFailed struct { + Track *webrtc.TrackLocalStaticRTP +} + +type NewICECandidate struct { + Candidate *webrtc.ICECandidate +} + +type ICEGatheringComplete struct{} + +type RenegotiationRequired struct { + Offer *webrtc.SessionDescription +} + +type DataChannelMessage struct { + Message string +} + +type DataChannelAvailable struct{} + +type RTCPReceived struct { + Packets []rtcp.Packet + StreamID string + TrackID string +} diff --git a/pkg/peer/peer.go b/pkg/peer/peer.go new file mode 100644 index 0000000..a226633 --- /dev/null +++ b/pkg/peer/peer.go @@ -0,0 +1,272 @@ +package peer + +import ( + "errors" + "io" + "sync" + "time" + + "github.com/matrix-org/waterfall/pkg/common" + "github.com/pion/rtcp" + "github.com/pion/webrtc/v3" + "github.com/sirupsen/logrus" + "golang.org/x/exp/slices" + "maunium.net/go/mautrix/event" +) + +var ( + ErrCantCreatePeerConnection = errors.New("can't create peer connection") + ErrCantSetRemoteDescription = errors.New("can't set remote description") + ErrCantCreateAnswer = errors.New("can't create answer") + ErrCantSetLocalDescription = errors.New("can't set local description") + ErrCantCreateLocalDescription = errors.New("can't create local description") + ErrDataChannelNotAvailable = errors.New("data channel is not available") + ErrDataChannelNotReady = errors.New("data channel is not ready") + ErrCantSubscribeToTrack = errors.New("can't subscribe to track") + ErrCantWriteRTCP = errors.New("can't write RTCP") +) + +// A wrapped representation of the peer connection (single peer in the call). +// The peer gets information about the things happening outside via public methods +// and informs the outside world about the things happening inside the peer by posting +// the messages to the channel. +type Peer[ID comparable] struct { + logger *logrus.Entry + peerConnection *webrtc.PeerConnection + sink *common.MessageSink[ID, MessageContent] + heartbeat chan HeartBeat + + dataChannelMutex sync.Mutex + dataChannel *webrtc.DataChannel +} + +// Instantiates a new peer with a given SDP offer and returns a peer and the SDP answer if everything is ok. +func NewPeer[ID comparable]( + sdpOffer string, + sink *common.MessageSink[ID, MessageContent], + logger *logrus.Entry, + keepAliveDeadline time.Duration, +) (*Peer[ID], *webrtc.SessionDescription, error) { + peerConnection, err := webrtc.NewPeerConnection(webrtc.Configuration{}) + if err != nil { + logger.WithError(err).Error("failed to create peer connection") + return nil, nil, ErrCantCreatePeerConnection + } + + peer := &Peer[ID]{ + logger: logger, + peerConnection: peerConnection, + sink: sink, + heartbeat: make(chan HeartBeat, common.UnboundedChannelSize), + } + + peerConnection.OnTrack(peer.onRtpTrackReceived) + peerConnection.OnDataChannel(peer.onDataChannelReady) + peerConnection.OnICECandidate(peer.onICECandidateGathered) + peerConnection.OnNegotiationNeeded(peer.onNegotiationNeeded) + peerConnection.OnICEConnectionStateChange(peer.onICEConnectionStateChanged) + peerConnection.OnICEGatheringStateChange(peer.onICEGatheringStateChanged) + peerConnection.OnConnectionStateChange(peer.onConnectionStateChanged) + peerConnection.OnSignalingStateChange(peer.onSignalingStateChanged) + + if sdpAnswer, err := peer.ProcessSDPOffer(sdpOffer); err != nil { + return nil, nil, err + } else { + onDeadline := func() { peer.sink.Send(LeftTheCall{event.CallHangupKeepAliveTimeout}) } + startKeepAlive(keepAliveDeadline, peer.heartbeat, onDeadline) + return peer, sdpAnswer, nil + } +} + +// Closes peer connection. From this moment on, no new messages will be sent from the peer. +func (p *Peer[ID]) Terminate() { + if err := p.peerConnection.Close(); err != nil { + p.logger.WithError(err).Error("failed to close peer connection") + } + + // We want to seal the channel since the sender is not interested in us anymore. + // We may want to remove this logic if/once we want to receive messages (confirmation of close or whatever) + // from the peer that is considered closed. + p.sink.Seal() +} + +// Adds given track to our peer connection, so that it can be sent to the remote peer. +func (p *Peer[ID]) SubscribeTo(track *webrtc.TrackLocalStaticRTP) error { + rtpSender, err := p.peerConnection.AddTrack(track) + if err != nil { + p.logger.WithError(err).Error("failed to add track") + return ErrCantSubscribeToTrack + } + + // Read incoming RTCP packets + // Before these packets are returned they are processed by interceptors. For things + // like NACK this needs to be called. + go func() { + for { + packets, _, err := rtpSender.ReadRTCP() + if err != nil { + if errors.Is(err, io.ErrClosedPipe) || errors.Is(err, io.EOF) { + return + } + + p.logger.WithError(err).Warn("failed to read RTCP on track") + } + + p.sink.Send(RTCPReceived{Packets: packets, TrackID: track.ID(), StreamID: track.StreamID()}) + } + }() + + return nil +} + +func (p *Peer[ID]) WriteRTCP(packets []rtcp.Packet, streamID string, trackID string, lastPLITimestamp time.Time) error { + const minimalPLIInterval = time.Millisecond * 500 + + packetsToSend := []rtcp.Packet{} + var mediaSSRC uint32 + receivers := p.peerConnection.GetReceivers() + receiverIndex := slices.IndexFunc(receivers, func(receiver *webrtc.RTPReceiver) bool { + return receiver.Track().ID() == trackID && receiver.Track().StreamID() == streamID + }) + + if receiverIndex == -1 { + p.logger.Error("failed to find track to write RTCP on") + return ErrCantWriteRTCP + } else { + mediaSSRC = uint32(receivers[receiverIndex].Track().SSRC()) + } + + for _, packet := range packets { + switch typedPacket := packet.(type) { + // We mung the packets here, so that the SSRCs match what the + // receiver expects: + // The media SSRC is the SSRC of the media about which the packet is + // reporting; therefore, we mung it to be the SSRC of the publishing + // participant's track. Without this, it would be SSRC of the SFU's + // track which isn't right + case *rtcp.PictureLossIndication: + // Since we sometimes spam the sender with PLIs, make sure we don't send + // them way too often + if time.Now().UnixNano()-lastPLITimestamp.UnixNano() < minimalPLIInterval.Nanoseconds() { + continue + } + + typedPacket.MediaSSRC = mediaSSRC + packetsToSend = append(packetsToSend, typedPacket) + case *rtcp.FullIntraRequest: + typedPacket.MediaSSRC = mediaSSRC + packetsToSend = append(packetsToSend, typedPacket) + } + + packetsToSend = append(packetsToSend, packet) + } + + if len(packetsToSend) != 0 { + if err := p.peerConnection.WriteRTCP(packetsToSend); err != nil { + if !errors.Is(err, io.ErrClosedPipe) { + p.logger.WithError(err).Error("failed to write RTCP on track") + return err + } + } + } + + return nil +} + +// Unsubscribes from the given list of tracks. +func (p *Peer[ID]) UnsubscribeFrom(tracks []*webrtc.TrackLocalStaticRTP) { + // That's unfortunately an O(m*n) operation, but we don't expect the number of tracks to be big. + for _, sender := range p.peerConnection.GetSenders() { + currentTrack := sender.Track() + if currentTrack == nil { + return + } + + for _, trackToUnsubscribe := range tracks { + presentTrackID, presentStreamID := currentTrack.ID(), currentTrack.StreamID() + trackID, streamID := trackToUnsubscribe.ID(), trackToUnsubscribe.StreamID() + if presentTrackID == trackID && presentStreamID == streamID { + if err := p.peerConnection.RemoveTrack(sender); err != nil { + p.logger.WithError(err).Error("failed to remove track") + } + } + } + } +} + +// Tries to send the given message to the remote counterpart of our peer. +func (p *Peer[ID]) SendOverDataChannel(json string) error { + p.dataChannelMutex.Lock() + defer p.dataChannelMutex.Unlock() + + if p.dataChannel == nil { + p.logger.Error("can't send data over data channel: data channel is not ready") + return ErrDataChannelNotAvailable + } + + if p.dataChannel.ReadyState() != webrtc.DataChannelStateOpen { + p.logger.Error("can't send data over data channel: data channel is not open") + return ErrDataChannelNotReady + } + + if err := p.dataChannel.SendText(json); err != nil { + p.logger.WithError(err).Error("failed to send data over data channel") + } + + return nil +} + +// Processes the remote ICE candidates. +func (p *Peer[ID]) ProcessNewRemoteCandidates(candidates []webrtc.ICECandidateInit) { + for _, candidate := range candidates { + if err := p.peerConnection.AddICECandidate(candidate); err != nil { + p.logger.WithError(err).Error("failed to add ICE candidate") + } + } +} + +// Processes the SDP answer received from the remote peer. +func (p *Peer[ID]) ProcessSDPAnswer(sdpAnswer string) error { + err := p.peerConnection.SetRemoteDescription(webrtc.SessionDescription{ + Type: webrtc.SDPTypeAnswer, + SDP: sdpAnswer, + }) + if err != nil { + p.logger.WithError(err).Error("failed to set remote description") + return ErrCantSetRemoteDescription + } + + return nil +} + +// Applies the sdp offer received from the remote peer and generates an SDP answer. +func (p *Peer[ID]) ProcessSDPOffer(sdpOffer string) (*webrtc.SessionDescription, error) { + err := p.peerConnection.SetRemoteDescription(webrtc.SessionDescription{ + Type: webrtc.SDPTypeOffer, + SDP: sdpOffer, + }) + if err != nil { + p.logger.WithError(err).Error("failed to set remote description") + return nil, ErrCantSetRemoteDescription + } + + answer, err := p.peerConnection.CreateAnswer(nil) + if err != nil { + p.logger.WithError(err).Error("failed to create answer") + return nil, ErrCantCreateAnswer + } + + if err := p.peerConnection.SetLocalDescription(answer); err != nil { + p.logger.WithError(err).Error("failed to set local description") + return nil, ErrCantSetLocalDescription + } + + return &answer, nil +} + +// New heartbeat received (keep-alive message that is periodically sent by the remote peer). +// We need to update the last heartbeat time. If the peer is not active for too long, we will +// consider peer's connection as stalled and will close it. +func (p *Peer[ID]) ProcessHeartbeat() { + p.heartbeat <- HeartBeat{} +} diff --git a/pkg/peer/webrtc.go b/pkg/peer/webrtc.go new file mode 100644 index 0000000..8806689 --- /dev/null +++ b/pkg/peer/webrtc.go @@ -0,0 +1,154 @@ +package peer + +import ( + "errors" + "io" + + "github.com/pion/webrtc/v3" + "maunium.net/go/mautrix/event" +) + +// A callback that is called once we receive first RTP packets from a track, i.e. +// we call this function each time a new track is received. +func (p *Peer[ID]) onRtpTrackReceived(remoteTrack *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) { + // Create a local track, all our SFU clients that are subscribed to this + // peer (publisher) wil be fed via this track. + localTrack, err := webrtc.NewTrackLocalStaticRTP( + remoteTrack.Codec().RTPCodecCapability, + remoteTrack.ID(), + remoteTrack.StreamID(), + ) + if err != nil { + p.logger.WithError(err).Error("failed to create local track") + return + } + + // Notify others that our track has just been published. + p.sink.Send(NewTrackPublished{Track: localTrack}) + + // Start forwarding the data from the remote track to the local track, + // so that everyone who is subscribed to this track will receive the data. + go func() { + rtpBuf := make([]byte, 1400) + + for { + index, _, readErr := remoteTrack.Read(rtpBuf) + if readErr != nil { + if readErr == io.EOF { // finished, no more data, no error, inform others + p.logger.Info("remote track closed") + } else { // finished, no more data, but with error, inform others + p.logger.WithError(readErr).Error("failed to read from remote track") + } + p.sink.Send(PublishedTrackFailed{Track: localTrack}) + return + } + + // ErrClosedPipe means we don't have any subscribers, this is ok if no peers have connected yet. + if _, err = localTrack.Write(rtpBuf[:index]); err != nil && !errors.Is(err, io.ErrClosedPipe) { + p.logger.WithError(err).Error("failed to write to local track") + p.sink.Send(PublishedTrackFailed{Track: localTrack}) + return + } + } + }() +} + +// A callback that is called once we receive an ICE candidate for this peer connection. +func (p *Peer[ID]) onICECandidateGathered(candidate *webrtc.ICECandidate) { + if candidate == nil { + p.logger.Info("ICE candidate gathering finished") + p.sink.Send(ICEGatheringComplete{}) + return + } + + p.logger.WithField("candidate", candidate).Debug("ICE candidate gathered") + p.sink.Send(NewICECandidate{Candidate: candidate}) +} + +// A callback that is called once we receive an ICE connection state change for this peer connection. +func (p *Peer[ID]) onNegotiationNeeded() { + p.logger.Debug("negotiation needed") + offer, err := p.peerConnection.CreateOffer(nil) + if err != nil { + p.logger.WithError(err).Error("failed to create offer") + return + } + + if err := p.peerConnection.SetLocalDescription(offer); err != nil { + p.logger.WithError(err).Error("failed to set local description") + return + } + + p.sink.Send(RenegotiationRequired{Offer: &offer}) +} + +// A callback that is called once we receive an ICE connection state change for this peer connection. +func (p *Peer[ID]) onICEConnectionStateChanged(state webrtc.ICEConnectionState) { + p.logger.WithField("state", state).Info("ICE connection state changed") + + switch state { + case webrtc.ICEConnectionStateFailed, webrtc.ICEConnectionStateDisconnected: + // TODO: Ask Simon if we should do it here as in the previous implementation. + // Ideally we want to perform an ICE restart here. + // p.notify <- PeerLeftTheCall{sender: p.data} + case webrtc.ICEConnectionStateCompleted, webrtc.ICEConnectionStateConnected: + // FIXME: Start keep-alive timer over the data channel to check the connecitons that hanged. + // p.notify <- PeerJoinedTheCall{sender: p.data} + } +} + +func (p *Peer[ID]) onICEGatheringStateChanged(state webrtc.ICEGathererState) { + p.logger.WithField("state", state).Debug("ICE gathering state changed") +} + +func (p *Peer[ID]) onSignalingStateChanged(state webrtc.SignalingState) { + p.logger.WithField("state", state).Debug("signaling state changed") +} + +func (p *Peer[ID]) onConnectionStateChanged(state webrtc.PeerConnectionState) { + p.logger.WithField("state", state).Info("Connection state changed") + + switch state { + case webrtc.PeerConnectionStateFailed, webrtc.PeerConnectionStateDisconnected, webrtc.PeerConnectionStateClosed: + p.sink.Send(LeftTheCall{event.CallHangupUserHangup}) + case webrtc.PeerConnectionStateConnected: + p.sink.Send(JoinedTheCall{}) + } +} + +// A callback that is called once the data channel is ready to be used. +func (p *Peer[ID]) onDataChannelReady(dc *webrtc.DataChannel) { + p.dataChannelMutex.Lock() + defer p.dataChannelMutex.Unlock() + + if p.dataChannel != nil { + p.logger.Error("Data channel already exists") + p.dataChannel.Close() + return + } + + p.dataChannel = dc + p.logger.WithField("label", dc.Label()).Info("Data channel ready") + + dc.OnOpen(func() { + p.logger.Info("Data channel opened") + p.sink.Send(DataChannelAvailable{}) + }) + + dc.OnMessage(func(msg webrtc.DataChannelMessage) { + p.logger.WithField("message", msg).Debug("Data channel message received") + if msg.IsString { + p.sink.Send(DataChannelMessage{Message: string(msg.Data)}) + } else { + p.logger.Warn("Data channel message is not a string, ignoring") + } + }) + + dc.OnError(func(err error) { + p.logger.WithError(err).Error("Data channel error") + }) + + dc.OnClose(func() { + p.logger.Info("Data channel closed") + }) +} diff --git a/pkg/profile.go b/pkg/profile.go new file mode 100644 index 0000000..6ca8b45 --- /dev/null +++ b/pkg/profile.go @@ -0,0 +1,53 @@ +package main + +import ( + "os" + "runtime" + "runtime/pprof" + + "github.com/sirupsen/logrus" +) + +// Initializes CPU profiling and returns a function to stop profiling. +func InitCPUProfiling(cpuProfile *string) func() { + logrus.Info("initializing CPU profiling") + + file, err := os.Create(*cpuProfile) + if err != nil { + logrus.WithError(err).Fatal("could not create CPU profile") + } + + if err := pprof.StartCPUProfile(file); err != nil { + logrus.WithError(err).Fatal("could not start CPU profile") + } + + return func() { + pprof.StopCPUProfile() + + if err := file.Close(); err != nil { + logrus.WithError(err).Fatal("could not close CPU profile") + } + } +} + +// Initializes memory profiling and returns a function to stop profiling. +func InitMemoryProfiling(memProfile *string) func() { + logrus.Info("initializing memory profiling") + + return func() { + file, err := os.Create(*memProfile) + if err != nil { + logrus.WithError(err).Fatal("could not create memory profile") + } + + runtime.GC() + + if err := pprof.WriteHeapProfile(file); err != nil { + logrus.WithError(err).Fatal("could not write memory profile") + } + + if err = file.Close(); err != nil { + logrus.WithError(err).Fatal("could not close memory profile") + } + } +} diff --git a/pkg/router.go b/pkg/router.go new file mode 100644 index 0000000..c502b18 --- /dev/null +++ b/pkg/router.go @@ -0,0 +1,203 @@ +/* +Copyright 2022 The Matrix.org Foundation C.I.C. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package main + +import ( + "github.com/matrix-org/waterfall/pkg/common" + conf "github.com/matrix-org/waterfall/pkg/conference" + "github.com/matrix-org/waterfall/pkg/signaling" + "github.com/sirupsen/logrus" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +type Conference = common.Sender[conf.MatrixMessage] + +// The top-level state of the Router. +type Router struct { + // Matrix matrix. + matrix *signaling.MatrixClient + // Sinks of all conferences (all calls that are currently forwarded by this SFU). + conferenceSinks map[string]*Conference + // Configuration for the calls. + config conf.Config + // A channel to serialize all incoming events to the Router. + channel chan RouterMessage +} + +// Creates a new instance of the SFU with the given configuration. +func newRouter(matrix *signaling.MatrixClient, config conf.Config) chan<- RouterMessage { + router := &Router{ + matrix: matrix, + conferenceSinks: make(map[string]*common.Sender[conf.MatrixMessage]), + config: config, + channel: make(chan RouterMessage, common.UnboundedChannelSize), + } + + // Start the main loop of the Router. + go func() { + for msg := range router.channel { + switch msg := msg.(type) { + // To-Device message received from the remote peer. + case MatrixMessage: + router.handleMatrixEvent(msg) + // One of the conferences has ended. + case ConferenceEndedMessage: + // Remove the conference that ended from the list. + delete(router.conferenceSinks, msg.conferenceID) + + // Process the message that was not read by the conference. + for _, msg := range msg.unread { + // TODO: We actually already know the type, so we can do this better. + router.handleMatrixEvent(msg.RawEvent) + } + } + } + }() + + return router.channel +} + +// Handles incoming To-Device events that the SFU receives from clients. +func (r *Router) handleMatrixEvent(evt *event.Event) { + var ( + conferenceID string + callID string + deviceID string + userID = evt.Sender + ) + + // Check if `conf_id` is present in the message (right messages do have it). + rawConferenceID, okConferenceId := evt.Content.Raw["conf_id"] + rawCallID, okCallId := evt.Content.Raw["call_id"] + rawDeviceID, okDeviceID := evt.Content.Raw["device_id"] + + if okConferenceId && okCallId && okDeviceID { + // Extract the conference ID from the message. + conferenceID, okConferenceId = rawConferenceID.(string) + callID, okCallId = rawCallID.(string) + deviceID, okDeviceID = rawDeviceID.(string) + + if !okConferenceId || !okCallId || !okDeviceID { + logrus.Warn("Ignoring invalid message without IDs") + return + } + } + + logger := logrus.WithFields(logrus.Fields{ + "type": evt.Type.Type, + "user_id": userID, + "conf_id": conferenceID, + "device_id": deviceID, + }) + + conference := r.conferenceSinks[conferenceID] + + // Only ToDeviceCallInvite events are allowed to create a new conference, others + // are expected to operate on an existing conference that is running on the SFU. + if conference == nil && evt.Type.Type == event.ToDeviceCallInvite.Type { + logger.Infof("creating new conference %s", conferenceID) + conferenceSink, err := conf.StartConference( + conferenceID, + r.config, + r.matrix.CreateForConference(conferenceID), + createConferenceEndNotifier(conferenceID, r.channel), + userID, + evt.Content.AsCallInvite(), + ) + if err != nil { + logger.WithError(err).Errorf("failed to start conference %s", conferenceID) + return + } + + r.conferenceSinks[conferenceID] = conferenceSink + return + } + + // All other events are expected to be handled by the existing conference. + if conference == nil { + logger.Warnf("ignoring %s since the conference is unknown", evt.Type.Type) + return + } + + // A helper function to deal with messages that can't be sent due to the conference closed. + // Not a function due to the need to capture local environment. + sendToConference := func(eventContent conf.MessageContent) { + sender := conf.ParticipantID{userID, id.DeviceID(deviceID), callID} + // At this point the conference is not nil. + // Let's check if the channel is still available. + if conference.Send(conf.MatrixMessage{Content: eventContent, RawEvent: evt, Sender: sender}) != nil { + // If sending failed, then the conference is over. + delete(r.conferenceSinks, conferenceID) + // Since we were not able to send the message, let's re-process it now. + // Note, we probably do not want to block here! + r.handleMatrixEvent(evt) + } + } + + switch evt.Type.Type { + // Someone tries to participate in a call (join a call). + case event.ToDeviceCallInvite.Type: + // If there is an invitation sent and the conference does not exist, create one. + sendToConference(evt.Content.AsCallInvite()) + case event.ToDeviceCallCandidates.Type: + // Someone tries to send ICE candidates to the existing call. + sendToConference(evt.Content.AsCallCandidates()) + case event.ToDeviceCallSelectAnswer.Type: + // Someone informs us about them accepting our (SFU's) SDP answer for an existing call. + sendToConference(evt.Content.AsCallSelectAnswer()) + case event.ToDeviceCallHangup.Type: + // Someone tries to inform us about leaving an existing call. + sendToConference(evt.Content.AsCallHangup()) + default: + logger.Warnf("ignoring event that we must not receive: %s", evt.Type.Type) + } +} + +type RouterMessage = interface{} + +type MatrixMessage = *event.Event + +// Message that is sent from the conference when the conference is ended. +type ConferenceEndedMessage struct { + // The ID of the conference that has ended. + conferenceID string + // A message (or messages in future) that has not been processed (if any). + unread []conf.MatrixMessage +} + +// A simple wrapper around channel that contains the ID of the conference that sent the message. +type ConferenceEndNotifier struct { + conferenceID string + channel chan<- interface{} +} + +// Crates a simple notifier with a conference with a given ID. +func createConferenceEndNotifier(conferenceID string, channel chan<- RouterMessage) *ConferenceEndNotifier { + return &ConferenceEndNotifier{ + conferenceID: conferenceID, + channel: channel, + } +} + +// A function that a conference calls when it is ended. +func (c *ConferenceEndNotifier) Notify(unread []conf.MatrixMessage) { + c.channel <- ConferenceEndedMessage{ + conferenceID: c.conferenceID, + unread: unread, + } +} diff --git a/pkg/signaling/client.go b/pkg/signaling/client.go new file mode 100644 index 0000000..c00e519 --- /dev/null +++ b/pkg/signaling/client.go @@ -0,0 +1,67 @@ +package signaling + +import ( + "github.com/sirupsen/logrus" + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/event" +) + +type MatrixClient struct { + client *mautrix.Client +} + +func NewMatrixClient(config Config) *MatrixClient { + client, err := mautrix.NewClient(config.HomeserverURL, config.UserID, config.AccessToken) + if err != nil { + logrus.WithError(err).Fatal("Failed to create client") + } + + whoami, err := client.Whoami() + if err != nil { + logrus.WithError(err).Fatal("Failed to identify SFU user") + } + + if config.UserID != whoami.UserID { + logrus.WithField("user_id", config.UserID).Fatal("Access token is for the wrong user") + } + + logrus.WithField("device_id", whoami.DeviceID).Info("Identified SFU as DeviceID") + client.DeviceID = whoami.DeviceID + + return &MatrixClient{ + client: client, + } +} + +// Starts the Matrix client and connects to the homeserver, +// Returns only when the sync with Matrix fails. +func (m *MatrixClient) RunSyncing(callback func(*event.Event)) { + syncer, ok := m.client.Syncer.(*mautrix.DefaultSyncer) + if !ok { + logrus.Panic("Syncer is not DefaultSyncer") + } + + syncer.ParseEventContent = true + syncer.OnEvent(func(_ mautrix.EventSource, evt *event.Event) { + // We only care about to-device events. + if evt.Type.Class != event.ToDeviceEventType { + logrus.Warn("ignoring a not to-device event") + return + } + + // We drop the messages if they are not meant for us. + if evt.Content.Raw["dest_session_id"] != LocalSessionID { + logrus.Warn("SessionID does not match our SessionID - ignoring") + return + } + + callback(evt) + }) + + // TODO: We may want to reconnect if `Sync()` fails instead of ending the SFU + // as ending here will essentially drop all conferences which may not necessarily + // be what we want for the existing running conferences. + if err := m.client.Sync(); err != nil { + logrus.WithError(err).Panic("Sync failed") + } +} diff --git a/pkg/signaling/config.go b/pkg/signaling/config.go new file mode 100644 index 0000000..acd730c --- /dev/null +++ b/pkg/signaling/config.go @@ -0,0 +1,13 @@ +package signaling + +import "maunium.net/go/mautrix/id" + +// Configuration for the Matrix client. +type Config struct { + // The Matrix ID (MXID) of the SFU. + UserID id.UserID `yaml:"userid"` + // The ULR of the homeserver that SFU talks to. + HomeserverURL string `yaml:"homeserverurl"` + // The access token for the Matrix SDK. + AccessToken string `yaml:"accesstoken"` +} diff --git a/pkg/signaling/matrix.go b/pkg/signaling/matrix.go new file mode 100644 index 0000000..245f022 --- /dev/null +++ b/pkg/signaling/matrix.go @@ -0,0 +1,148 @@ +/* +Copyright 2022 The Matrix.org Foundation C.I.C. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package signaling + +import ( + "github.com/sirupsen/logrus" + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +const LocalSessionID = "sfu" + +// Matrix client scoped for a particular conference. +type MatrixForConference struct { + client *mautrix.Client + conferenceID string +} + +// Create a new Matrix client that abstracts outgoing Matrix messages from a given conference. +func (m *MatrixClient) CreateForConference(conferenceID string) *MatrixForConference { + return &MatrixForConference{ + client: m.client, + conferenceID: conferenceID, + } +} + +// Defines the data that identifies a receiver of Matrix's to-device message. +type MatrixRecipient struct { + UserID id.UserID + DeviceID id.DeviceID + RemoteSessionID id.SessionID + CallID string +} + +// Interface that abstracts sending Send-to-device messages for the conference. +type MatrixSignaling interface { + SendSDPAnswer(recipient MatrixRecipient, streamMetadata event.CallSDPStreamMetadata, sdp string) + SendICECandidates(recipient MatrixRecipient, candidates []event.CallCandidate) + SendCandidatesGatheringFinished(recipient MatrixRecipient) + SendHangup(recipient MatrixRecipient, reason event.CallHangupReason) + DeviceID() id.DeviceID +} + +func (m *MatrixForConference) SendSDPAnswer( + recipient MatrixRecipient, + streamMetadata event.CallSDPStreamMetadata, + sdp string, +) { + eventContent := &event.Content{ + Parsed: event.CallAnswerEventContent{ + BaseCallEventContent: m.createBaseEventContent(recipient.CallID, recipient.RemoteSessionID), + Answer: event.CallData{ + Type: "answer", + SDP: sdp, + }, + SDPStreamMetadata: streamMetadata, + }, + } + + m.sendToDevice(recipient, event.CallAnswer, eventContent) +} + +func (m *MatrixForConference) SendICECandidates(recipient MatrixRecipient, candidates []event.CallCandidate) { + eventContent := &event.Content{ + Parsed: event.CallCandidatesEventContent{ + BaseCallEventContent: m.createBaseEventContent(recipient.CallID, recipient.RemoteSessionID), + Candidates: candidates, + }, + } + + m.sendToDevice(recipient, event.CallCandidates, eventContent) +} + +func (m *MatrixForConference) SendCandidatesGatheringFinished(recipient MatrixRecipient) { + eventContent := &event.Content{ + Parsed: event.CallCandidatesEventContent{ + BaseCallEventContent: m.createBaseEventContent(recipient.CallID, recipient.RemoteSessionID), + Candidates: []event.CallCandidate{{Candidate: ""}}, + }, + } + + m.sendToDevice(recipient, event.CallCandidates, eventContent) +} + +func (m *MatrixForConference) SendHangup(recipient MatrixRecipient, reason event.CallHangupReason) { + eventContent := &event.Content{ + Parsed: event.CallHangupEventContent{ + BaseCallEventContent: m.createBaseEventContent(recipient.CallID, recipient.RemoteSessionID), + Reason: reason, + }, + } + + m.sendToDevice(recipient, event.CallHangup, eventContent) +} + +func (m *MatrixForConference) DeviceID() id.DeviceID { + return m.client.DeviceID +} + +func (m *MatrixForConference) createBaseEventContent( + callID string, + destSessionID id.SessionID, +) event.BaseCallEventContent { + return event.BaseCallEventContent{ + CallID: callID, + ConfID: m.conferenceID, + DeviceID: m.client.DeviceID, + SenderSessionID: LocalSessionID, + DestSessionID: destSessionID, + PartyID: string(m.client.DeviceID), + Version: event.CallVersion("1"), + } +} + +// Sends a to-device event to the given user. +func (m *MatrixForConference) sendToDevice(user MatrixRecipient, eventType event.Type, eventContent *event.Content) { + logger := logrus.WithFields(logrus.Fields{ + "user_id": user.UserID, + "device_id": user.DeviceID, + }) + + sendRequest := &mautrix.ReqSendToDevice{ + Messages: map[id.UserID]map[id.DeviceID]*event.Content{ + user.UserID: { + user.DeviceID: eventContent, + }, + }, + } + + if _, err := m.client.SendToDevice(eventType, sendRequest); err != nil { + logger.Errorf("failed to send to-device event: %w", err) + } +} diff --git a/scripts/build.sh b/scripts/build.sh index 9702abb..56595bf 100755 --- a/scripts/build.sh +++ b/scripts/build.sh @@ -1,3 +1,3 @@ #!/usr/bin/env bash -go build -o dist/waterfall ./src +go build -o dist/waterfall ./pkg diff --git a/scripts/profile.sh b/scripts/profile.sh index 071103e..7bc0ee9 100755 --- a/scripts/profile.sh +++ b/scripts/profile.sh @@ -1,3 +1,3 @@ #!/usr/bin/env bash -go run ./src/*.go --cpuProfile cpuProfile.pprof --memProfile memProfile.pprof --logTime +go run ./pkg/*.go --cpuProfile cpuProfile.pprof --memProfile memProfile.pprof diff --git a/scripts/run.sh b/scripts/run.sh index 50c2f32..6381271 100755 --- a/scripts/run.sh +++ b/scripts/run.sh @@ -1,3 +1,3 @@ #!/usr/bin/env bash -go run ./src --logTime +go run ./pkg diff --git a/src/call.go b/src/call.go deleted file mode 100644 index 490c021..0000000 --- a/src/call.go +++ /dev/null @@ -1,565 +0,0 @@ -/* -Copyright 2022 The Matrix.org Foundation C.I.C. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package main - -import ( - "encoding/json" - "sync" - "time" - - "github.com/pion/webrtc/v3" - "github.com/sirupsen/logrus" - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" -) - -type Call struct { - PeerConnection *webrtc.PeerConnection - - CallID string - UserID id.UserID - DeviceID id.DeviceID - LocalSessionID id.SessionID - RemoteSessionID id.SessionID - - Publishers []*Publisher - Subscribers []*Subscriber - - mutex sync.RWMutex - logger *logrus.Entry - client *mautrix.Client - conf *Conference - - dataChannel *webrtc.DataChannel - lastKeepAliveTimestamp time.Time - sentEndOfCandidates bool -} - -func NewCall(callID string, conf *Conference) *Call { - call := new(Call) - - call.CallID = callID - call.conf = conf - - return call -} - -func (c *Call) InitWithInvite(evt *event.Event, client *mautrix.Client) { - invite := evt.Content.AsCallInvite() - - c.UserID = evt.Sender - c.DeviceID = invite.DeviceID - // XXX: What if an SFU gets restarted? - c.LocalSessionID = localSessionID - c.RemoteSessionID = invite.SenderSessionID - c.client = client - c.logger = logrus.WithFields(logrus.Fields{ - "user_id": evt.Sender, - "conf_id": invite.ConfID, - }) -} - -func (c *Call) onDCSelect(start []event.SFUTrackDescription) { - if len(start) == 0 { - return - } - - for _, trackDesc := range start { - trackLogger := c.logger.WithFields(logrus.Fields{ - "track_id": trackDesc.TrackID, - "stream_id": trackDesc.StreamID, - }) - - trackLogger.Info("selecting track") - - for _, publisher := range c.conf.GetPublishers() { - if publisher.Matches(trackDesc) { - publisher.Subscribe(c) - } - } - } -} - -func (c *Call) onDCPublish(sdp string) { - c.logger.Info("received DC publish") - - err := c.PeerConnection.SetRemoteDescription(webrtc.SessionDescription{ - Type: webrtc.SDPTypeOffer, - SDP: sdp, - }) - if err != nil { - c.logger.WithField("sdp", sdp).WithError(err).Error("failed to set remote description - ignoring") - return - } - - offer, err := c.PeerConnection.CreateAnswer(nil) - if err != nil { - c.logger.WithError(err).Error("failed to create answer - ignoring") - return - } - - err = c.PeerConnection.SetLocalDescription(offer) - if err != nil { - c.logger.WithField("sdp", offer.SDP).WithError(err).Error("failed to set local description - ignoring") - return - } - - c.SendDataChannelMessage(event.SFUMessage{ - Op: event.SFUOperationAnswer, - SDP: offer.SDP, - }) -} - -func (c *Call) onDCUnpublish(stop []event.SFUTrackDescription, sdp string) { - for _, trackDesc := range stop { - for _, publisher := range c.Publishers { - if publisher.Matches(trackDesc) { - publisher.Stop() - } - } - } - - err := c.PeerConnection.SetRemoteDescription(webrtc.SessionDescription{ - Type: webrtc.SDPTypeOffer, - SDP: sdp, - }) - if err != nil { - c.logger.WithField("sdp", sdp).WithError(err).Error("failed to set remote description - ignoring") - return - } - - offer, err := c.PeerConnection.CreateAnswer(nil) - if err != nil { - c.logger.WithError(err).Error("failed to create answer - ignoring") - return - } - - err = c.PeerConnection.SetLocalDescription(offer) - if err != nil { - c.logger.WithField("sdp", offer.SDP).WithError(err).Error("failed to set local description - ignoring") - return - } - - c.SendDataChannelMessage(event.SFUMessage{ - Op: event.SFUOperationAnswer, - SDP: offer.SDP, - }) -} - -func (c *Call) onDCAnswer(sdp string) { - c.logger.Info("received DC answer") - - err := c.PeerConnection.SetRemoteDescription(webrtc.SessionDescription{ - Type: webrtc.SDPTypeAnswer, - SDP: sdp, - }) - if err != nil { - c.logger.WithField("sdp", sdp).WithError(err).Error("failed to set remote description - ignoring") - return - } -} - -func (c *Call) onDCAlive() { - c.lastKeepAliveTimestamp = time.Now() -} - -func (c *Call) onDCMetadata() { - c.logger.Info("received DC metadata") - - c.conf.SendUpdatedMetadataFromCall(c.CallID) -} - -func (c *Call) dataChannelHandler(channel *webrtc.DataChannel) { - c.dataChannel = channel - - channel.OnOpen(func() { - c.SendDataChannelMessage(event.SFUMessage{Op: event.SFUOperationMetadata}) - }) - - channel.OnError(func(err error) { - logrus.Fatalf("%s | DC error: %s", c.CallID, err) - }) - - channel.OnMessage(func(marshaledMsg webrtc.DataChannelMessage) { - if !marshaledMsg.IsString { - c.logger.WithField("msg", marshaledMsg).Warn("inbound message is not string - ignoring") - return - } - - msg := &event.SFUMessage{} - if err := json.Unmarshal(marshaledMsg.Data, msg); err != nil { - c.logger.WithField("msg", marshaledMsg).WithError(err).Error("failed to unmarshal - ignoring") - return - } - - if msg.Metadata != nil { - c.conf.Metadata.Update(c.DeviceID, msg.Metadata) - } - - switch msg.Op { - case event.SFUOperationSelect: - c.onDCSelect(msg.Start) - case event.SFUOperationPublish: - c.onDCPublish(msg.SDP) - case event.SFUOperationUnpublish: - c.onDCUnpublish(msg.Stop, msg.SDP) - case event.SFUOperationAnswer: - c.onDCAnswer(msg.SDP) - case event.SFUOperationAlive: - c.onDCAlive() - case event.SFUOperationMetadata: - c.onDCMetadata() - - default: - c.logger.WithField("op", msg.Op).Warn("Unknown operation - ignoring") - // TODO: hook up msg.Stop to unsubscribe from tracks - // TODO: hook cascade back up. - // As we're not an AS, we'd rely on the client - // to send us a "connect" op to tell us how to - // connect to another focus in order to select - // its streams. - } - }) -} - -func (c *Call) negotiationNeededHandler() { - offer, err := c.PeerConnection.CreateOffer(nil) - if err != nil { - c.logger.WithError(err).Error("failed to create offer - ignoring") - return - } - - err = c.PeerConnection.SetLocalDescription(offer) - if err != nil { - c.logger.WithField("sdp", offer.SDP).WithError(err).Error("failed to set local description - ignoring") - return - } - - c.SendDataChannelMessage(event.SFUMessage{ - Op: event.SFUOperationOffer, - SDP: offer.SDP, - }) -} - -func (c *Call) iceCandidateHandler(candidate *webrtc.ICECandidate) { - if candidate == nil { - return - } - - jsonCandidate := candidate.ToJSON() - - candidateEvtContent := &event.Content{ - Parsed: event.CallCandidatesEventContent{ - BaseCallEventContent: event.BaseCallEventContent{ - CallID: c.CallID, - ConfID: c.conf.ConfID, - DeviceID: c.client.DeviceID, - SenderSessionID: c.LocalSessionID, - DestSessionID: c.RemoteSessionID, - PartyID: string(c.client.DeviceID), - Version: event.CallVersion("1"), - }, - Candidates: []event.CallCandidate{{ - Candidate: jsonCandidate.Candidate, - SDPMLineIndex: int(*jsonCandidate.SDPMLineIndex), - SDPMID: *jsonCandidate.SDPMid, - }}, - }, - } - c.sendToDevice(event.CallCandidates, candidateEvtContent) -} - -func (c *Call) trackHandler(trackRemote *webrtc.TrackRemote) { - NewPublisher(trackRemote, c) - - go c.conf.SendUpdatedMetadataFromCall(c.CallID) -} - -func (c *Call) iceConnectionStateHandler(state webrtc.ICEConnectionState) { - if state == webrtc.ICEConnectionStateCompleted || state == webrtc.ICEConnectionStateConnected { - c.lastKeepAliveTimestamp = time.Now() - go c.CheckKeepAliveTimestamp() - - if !c.sentEndOfCandidates { - candidateEvtContent := &event.Content{ - Parsed: event.CallCandidatesEventContent{ - BaseCallEventContent: event.BaseCallEventContent{ - CallID: c.CallID, - ConfID: c.conf.ConfID, - DeviceID: c.client.DeviceID, - SenderSessionID: c.LocalSessionID, - DestSessionID: c.RemoteSessionID, - PartyID: string(c.client.DeviceID), - Version: event.CallVersion("1"), - }, - Candidates: []event.CallCandidate{{Candidate: ""}}, - }, - } - c.sendToDevice(event.CallCandidates, candidateEvtContent) - c.sentEndOfCandidates = true - } - } -} - -func (c *Call) OnInvite(content *event.CallInviteEventContent) { - c.conf.Metadata.Update(c.DeviceID, content.SDPStreamMetadata) - offer := content.Offer - - peerConnection, err := webrtc.NewPeerConnection(webrtc.Configuration{}) - if err != nil { - logrus.WithError(err).Error("failed to create new peer connection") - } - - c.PeerConnection = peerConnection - - peerConnection.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) { - c.trackHandler(track) - }) - peerConnection.OnDataChannel(func(d *webrtc.DataChannel) { - c.dataChannelHandler(d) - }) - peerConnection.OnICECandidate(func(candidate *webrtc.ICECandidate) { - c.iceCandidateHandler(candidate) - }) - peerConnection.OnNegotiationNeeded(func() { - c.negotiationNeededHandler() - }) - peerConnection.OnICEConnectionStateChange(func(state webrtc.ICEConnectionState) { - c.iceConnectionStateHandler(state) - }) - - err = peerConnection.SetRemoteDescription(webrtc.SessionDescription{ - Type: webrtc.SDPTypeOffer, - SDP: offer.SDP, - }) - if err != nil { - c.logger.WithField("sdp", offer.SDP).WithError(err).Error("failed to set remote description - ignoring") - return - } - - answer, err := peerConnection.CreateAnswer(nil) - if err != nil { - c.logger.WithError(err).Error("failed to create answer - ignoring") - return - } - - // TODO: trickle ICE for fast conn setup, rather than block here - gatherComplete := webrtc.GatheringCompletePromise(peerConnection) - - if err = peerConnection.SetLocalDescription(answer); err != nil { - c.logger.WithField("sdp", offer.SDP).WithError(err).Error("failed to set local description - ignoring") - return - } - - <-gatherComplete - - answerEvtContent := &event.Content{ - Parsed: event.CallAnswerEventContent{ - BaseCallEventContent: event.BaseCallEventContent{ - CallID: c.CallID, - ConfID: c.conf.ConfID, - DeviceID: c.client.DeviceID, - SenderSessionID: c.LocalSessionID, - DestSessionID: c.RemoteSessionID, - PartyID: string(c.client.DeviceID), - Version: event.CallVersion("1"), - }, - Answer: event.CallData{ - Type: "answer", - SDP: peerConnection.LocalDescription().SDP, - }, - SDPStreamMetadata: c.conf.Metadata.GetForDevice(c.DeviceID), - }, - } - c.sendToDevice(event.CallAnswer, answerEvtContent) -} - -func (c *Call) OnSelectAnswer(content *event.CallSelectAnswerEventContent) { - selectedPartyID := content.SelectedPartyID - if selectedPartyID != string(c.client.DeviceID) { - c.logger.WithField("selected_party_id", selectedPartyID).Warn("call was answered on a different device") - c.Terminate() - } -} - -func (c *Call) OnHangup() { - c.Terminate() -} - -func (c *Call) OnCandidates(content *event.CallCandidatesEventContent) { - for _, candidate := range content.Candidates { - sdpMLineIndex := uint16(candidate.SDPMLineIndex) - ice := webrtc.ICECandidateInit{ - Candidate: candidate.Candidate, - SDPMid: &candidate.SDPMID, - SDPMLineIndex: &sdpMLineIndex, - UsernameFragment: new(string), - } - - if err := c.PeerConnection.AddICECandidate(ice); err != nil { - c.logger.WithField("content", content).WithError(err).Error("failed to add ICE candidate") - } - } -} - -func (c *Call) Terminate() { - c.logger.Info("terminating call") - - if err := c.PeerConnection.Close(); err != nil { - c.logger.WithError(err).Error("error closing peer connection") - } - - c.conf.mutex.Lock() - delete(c.conf.Calls, c.CallID) - c.conf.mutex.Unlock() - - for _, publisher := range c.Publishers { - publisher.Stop() - } - - for _, subscriber := range c.Subscribers { - subscriber.Unsubscribe() - } - - c.conf.SendUpdatedMetadataFromCall(c.CallID) -} - -func (c *Call) Hangup(reason event.CallHangupReason) { - hangupEvtContent := &event.Content{ - Parsed: event.CallHangupEventContent{ - BaseCallEventContent: event.BaseCallEventContent{ - CallID: c.CallID, - ConfID: c.conf.ConfID, - DeviceID: c.client.DeviceID, - SenderSessionID: c.LocalSessionID, - DestSessionID: c.RemoteSessionID, - PartyID: string(c.client.DeviceID), - Version: event.CallVersion("1"), - }, - Reason: reason, - }, - } - c.sendToDevice(event.CallHangup, hangupEvtContent) - c.Terminate() -} - -func (c *Call) sendToDevice(callType event.Type, content *event.Content) { - evtLogger := c.logger.WithFields(logrus.Fields{ - "type": callType.Type, - }) - - if callType.Type != event.CallCandidates.Type { - evtLogger.Info("sending to device") - } - - toDevice := &mautrix.ReqSendToDevice{ - Messages: map[id.UserID]map[id.DeviceID]*event.Content{ - c.UserID: { - c.DeviceID: content, - }, - }, - } - - // TODO: E2EE - // TODO: to-device reliability - if _, err := c.client.SendToDevice(callType, toDevice); err != nil { - evtLogger.WithField("content", content).WithError(err).Error("error sending to-device") - } -} - -func (c *Call) SendDataChannelMessage(msg event.SFUMessage) { - if c.dataChannel == nil { - return - } - - evtLogger := c.logger.WithFields(logrus.Fields{ - "op": msg.Op, - }) - - if msg.Metadata == nil { - msg.Metadata = c.conf.Metadata.GetForDevice(c.DeviceID) - if msg.Op == event.SFUOperationMetadata && len(msg.Metadata) == 0 { - return - } - } - - marshaled, err := json.Marshal(msg) - if err != nil { - evtLogger.WithField("msg", msg).WithError(err).Error("failed to marshal - ignoring") - return - } - - err = c.dataChannel.SendText(string(marshaled)) - if err != nil { - evtLogger.WithField("msg", msg).WithError(err).Error("failed to send message over DC") - } - - evtLogger.Info("sent message over DC") -} - -func (c *Call) CheckKeepAliveTimestamp() { - timeout := time.Second * time.Duration(c.conf.Config.KeepAliveTimeout) - for range time.Tick(timeout) { - if c.lastKeepAliveTimestamp.Add(timeout).Before(time.Now()) { - if c.PeerConnection.ConnectionState() != webrtc.PeerConnectionStateClosed { - c.logger.WithField("timeout", timeout).Warn("did not get keep-alive message") - c.Hangup(event.CallHangupKeepAliveTimeout) - } - - break - } - } -} - -func (c *Call) RemoveSubscriber(toDelete *Subscriber) bool { - removed := false - newSubscribers := []*Subscriber{} - - c.mutex.Lock() - for _, subscriber := range c.Subscribers { - if subscriber != toDelete { - removed = true - } else { - newSubscribers = append(newSubscribers, subscriber) - } - } - - c.Subscribers = newSubscribers - c.mutex.Unlock() - - return removed -} - -func (c *Call) RemovePublisher(toDelete *Publisher) bool { - removed := false - newPublishers := []*Publisher{} - - c.mutex.Lock() - for _, publisher := range c.Publishers { - if publisher == toDelete { - removed = true - } else { - newPublishers = append(newPublishers, publisher) - } - } - - c.Publishers = newPublishers - c.mutex.Unlock() - - return removed -} diff --git a/src/conference.go b/src/conference.go deleted file mode 100644 index 1d08064..0000000 --- a/src/conference.go +++ /dev/null @@ -1,115 +0,0 @@ -/* -Copyright 2022 The Matrix.org Foundation C.I.C. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package main - -import ( - "errors" - "sync" - - "github.com/sirupsen/logrus" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" -) - -var ( - ErrNoSuchCall = errors.New("no such call") - ErrFoundExistingCallWithSameSessionAndDeviceID = errors.New("found existing call with equal DeviceID and SessionID") -) - -// Configuration for the group conferences (calls). -type ConferenceConfig struct { - // Keep-alive timeout for WebRTC connections. If no keep-alive has been received - // from the client for this duration, the connection is considered dead. - KeepAliveTimeout int -} - -type Conference struct { - ConfID string - Calls map[string]*Call // By callID - Config *ConferenceConfig // TODO: this must be protected by a mutex actually - - mutex sync.RWMutex - logger *logrus.Entry - Metadata *Metadata -} - -func NewConference(confID string, config *ConferenceConfig) *Conference { - conference := new(Conference) - - conference.Config = config - conference.ConfID = confID - conference.Calls = make(map[string]*Call) - conference.Metadata = NewMetadata(conference) - conference.logger = logrus.WithFields(logrus.Fields{ - "conf_id": confID, - }) - - return conference -} - -func (c *Conference) GetCall(callID string, create bool) (*Call, error) { - c.mutex.Lock() - defer c.mutex.Unlock() - call := c.Calls[callID] - - if call == nil { - if create { - call = NewCall(callID, c) - c.Calls[callID] = call - } else { - return nil, ErrNoSuchCall - } - } - - return call, nil -} - -func (c *Conference) RemoveOldCallsByDeviceAndSessionIDs(deviceID id.DeviceID, sessionID id.SessionID) error { - var err error - - for _, call := range c.Calls { - if call.DeviceID == deviceID { - if call.RemoteSessionID == sessionID { - err = ErrFoundExistingCallWithSameSessionAndDeviceID - } else { - call.Terminate() - } - } - } - - return err -} - -func (c *Conference) SendUpdatedMetadataFromCall(callID string) { - for _, call := range c.Calls { - if call.CallID != callID { - call.SendDataChannelMessage(event.SFUMessage{Op: event.SFUOperationMetadata}) - } - } -} - -func (c *Conference) GetPublishers() []*Publisher { - publishers := []*Publisher{} - - c.mutex.RLock() - for _, call := range c.Calls { - publishers = append(publishers, call.Publishers...) - } - c.mutex.RUnlock() - - return publishers -} diff --git a/src/focus.go b/src/focus.go deleted file mode 100644 index c2851c6..0000000 --- a/src/focus.go +++ /dev/null @@ -1,176 +0,0 @@ -/* -Copyright 2022 The Matrix.org Foundation C.I.C. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package main - -import ( - "errors" - "strings" - "sync" - - "github.com/sirupsen/logrus" - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/event" -) - -var ErrNoSuchConference = errors.New("no such conf") - -type Confs struct { - confsMu sync.RWMutex - confs map[string]*Conference -} - -type Focus struct { - name string - client *mautrix.Client - confs Confs - logger *logrus.Entry - config *ConferenceConfig -} - -func NewFocus(name string, client *mautrix.Client, config *ConferenceConfig) *Focus { - focus := new(Focus) - - focus.name = name - focus.client = client - focus.confs.confs = make(map[string]*Conference) - focus.logger = logrus.WithFields(logrus.Fields{}) - focus.config = config - - return focus -} - -func (f *Focus) GetConf(confID string, create bool) (*Conference, error) { - f.confs.confsMu.Lock() - defer f.confs.confsMu.Unlock() - conference := f.confs.confs[confID] - - if conference == nil { - if create { - conference = NewConference(confID, f.config) - f.confs.confs[confID] = conference - } else { - return nil, ErrNoSuchConference - } - } - - return conference, nil -} - -func (f *Focus) getExistingCall(confID string, callID string) (*Call, error) { - var ( - conf *Conference - call *Call - err error - ) - - if conf, err = f.GetConf(confID, false); err != nil || conf == nil { - f.logger.Printf("failed to get conf %s: %s", confID, err) - return nil, err - } - - if call, err = conf.GetCall(callID, false); err != nil || call == nil { - f.logger.Printf("failed to get call %s: %s", callID, err) - return nil, err - } - - return call, nil -} - -func (f *Focus) onEvent(_ mautrix.EventSource, evt *event.Event) { - // We only care about to-device events - if evt.Type.Class != event.ToDeviceEventType { - return - } - - evtLogger := f.logger.WithFields(logrus.Fields{ - "type": evt.Type.Type, - "user_id": evt.Sender.String(), - "conf_id": evt.Content.Raw["conf_id"], - }) - - if !strings.HasPrefix(evt.Type.Type, "m.call.") && !strings.HasPrefix(evt.Type.Type, "org.matrix.call.") { - evtLogger.Warn("received non-call to-device event") - return - } else if evt.Type.Type != event.ToDeviceCallCandidates.Type && evt.Type.Type != event.ToDeviceCallSelectAnswer.Type { - evtLogger.Info("received to-device event") - } - - if evt.Content.Raw["dest_session_id"] != localSessionID { - evtLogger.WithField("dest_session_id", localSessionID).Warn("SessionID does not match our SessionID - ignoring") - return - } - - var ( - conf *Conference - call *Call - err error - ) - - switch evt.Type.Type { - case event.ToDeviceCallInvite.Type: - invite := evt.Content.AsCallInvite() - if conf, err = f.GetConf(invite.ConfID, true); err != nil { - evtLogger.WithError(err).WithFields(logrus.Fields{ - "conf_id": invite.ConfID, - }).Error("failed to create conf") - - return - } - - if err := conf.RemoveOldCallsByDeviceAndSessionIDs(invite.DeviceID, invite.SenderSessionID); err != nil { - evtLogger.WithError(err).Error("error removing old calls - ignoring call") - return - } - - if call, err = conf.GetCall(invite.CallID, true); err != nil || call == nil { - evtLogger.WithError(err).Error("failed to create call") - return - } - - call.InitWithInvite(evt, f.client) - call.OnInvite(invite) - case event.ToDeviceCallCandidates.Type: - candidates := evt.Content.AsCallCandidates() - if call, err = f.getExistingCall(candidates.ConfID, candidates.CallID); err != nil { - return - } - - call.OnCandidates(candidates) - case event.ToDeviceCallSelectAnswer.Type: - selectAnswer := evt.Content.AsCallSelectAnswer() - if call, err = f.getExistingCall(selectAnswer.ConfID, selectAnswer.CallID); err != nil { - return - } - - call.OnSelectAnswer(selectAnswer) - case event.ToDeviceCallHangup.Type: - hangup := evt.Content.AsCallHangup() - if call, err = f.getExistingCall(hangup.ConfID, hangup.CallID); err != nil { - return - } - - call.OnHangup() - // Events we don't care about - case event.ToDeviceCallNegotiate.Type: - evtLogger.Warn("ignoring event as it should be handled over DC") - case event.ToDeviceCallReject.Type: - case event.ToDeviceCallAnswer.Type: - evtLogger.Warn("ignoring event as we are always the ones answering") - default: - evtLogger.Warn("ignoring unrecognised event") - } -} diff --git a/src/logger.go b/src/logger.go deleted file mode 100644 index 2d849fc..0000000 --- a/src/logger.go +++ /dev/null @@ -1,68 +0,0 @@ -/* -Copyright 2022 The Matrix.org Foundation C.I.C. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package main - -import ( - "fmt" - - "github.com/sirupsen/logrus" -) - -type CustomTextFormatter struct { - logrus.TextFormatter - logTime bool -} - -func (f *CustomTextFormatter) Format(entry *logrus.Entry) ([]byte, error) { - // TODO: Use colors and make it pretty - level := entry.Level - timestamp := entry.Time.Format("2006-01-02 15:04:05") - confID := entry.Data["conf_id"] - userID := entry.Data["user_id"] - - logLine := fmt.Sprintf("%s|", level) - - if f.logTime { - logLine += fmt.Sprintf("%s|", timestamp) - } - - if confID != nil { - logLine += fmt.Sprintf("%v|", confID) - } - - if userID != nil { - logLine += fmt.Sprintf("%v|", userID) - } - - logLine += fmt.Sprintf(" %v ", entry.Message) - - fields := "" - - for key, value := range entry.Data { - if key != "conf_id" && key != "user_id" { - fields += fmt.Sprintf("%v=%v ", key, value) - } - } - - if fields != "" { - logLine += fmt.Sprintf("| %s", fields) - } - - logLine += "\n" - - return []byte(logLine), nil -} diff --git a/src/main.go b/src/main.go deleted file mode 100644 index 80dbaaf..0000000 --- a/src/main.go +++ /dev/null @@ -1,124 +0,0 @@ -/* -Copyright 2022 The Matrix.org Foundation C.I.C. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package main - -import ( - "flag" - "os" - "os/signal" - "runtime" - "runtime/pprof" - "syscall" - - "github.com/sirupsen/logrus" -) - -func initCPUProfiling(cpuProfile *string) func() { - logrus.Info("initializing CPU profiling") - - file, err := os.Create(*cpuProfile) - if err != nil { - logrus.WithError(err).Fatal("could not create CPU profile") - } - - if err := pprof.StartCPUProfile(file); err != nil { - logrus.WithError(err).Fatal("could not start CPU profile") - } - - return func() { - pprof.StopCPUProfile() - - if err := file.Close(); err != nil { - logrus.WithError(err).Fatal("could not close CPU profile") - } - } -} - -func initMemoryProfiling(memProfile *string) func() { - logrus.Info("initializing memory profiling") - - return func() { - file, err := os.Create(*memProfile) - if err != nil { - logrus.WithError(err).Fatal("could not create memory profile") - } - - runtime.GC() - - if err := pprof.WriteHeapProfile(file); err != nil { - logrus.WithError(err).Fatal("could not write memory profile") - } - - if err = file.Close(); err != nil { - logrus.WithError(err).Fatal("could not close memory profile") - } - } -} - -func initLogging(logTime *bool) { - formatter := new(CustomTextFormatter) - - formatter.logTime = *logTime - - logrus.SetFormatter(formatter) -} - -func killListener(c chan os.Signal, beforeExit []func()) { - <-c - - for _, function := range beforeExit { - function() - } - - defer os.Exit(0) -} - -func main() { - var ( - logTime = flag.Bool("logTime", false, "whether or not to print time and date in logs") - configFilePath = flag.String("config", "config.yaml", "configuration file path") - cpuProfile = flag.String("cpuProfile", "", "write CPU profile to `file`") - memProfile = flag.String("memProfile", "", "write memory profile to `file`") - ) - - flag.Parse() - - initLogging(logTime) - - beforeExit := []func(){} - if *cpuProfile != "" { - beforeExit = append(beforeExit, initCPUProfiling(cpuProfile)) - } - - if *memProfile != "" { - beforeExit = append(beforeExit, initMemoryProfiling(memProfile)) - } - - // try to handle os interrupt(signal terminated) - //nolint:gomnd - c := make(chan os.Signal, 2) - signal.Notify(c, os.Interrupt, syscall.SIGTERM) - - go killListener(c, beforeExit) - - config, err := loadConfig(*configFilePath) - if err != nil { - logrus.WithError(err).Fatal("could not load config") - } - - InitMatrix(config) -} diff --git a/src/matrix.go b/src/matrix.go deleted file mode 100644 index db493ba..0000000 --- a/src/matrix.go +++ /dev/null @@ -1,65 +0,0 @@ -/* -Copyright 2022 The Matrix.org Foundation C.I.C. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package main - -import ( - "fmt" - - "github.com/sirupsen/logrus" - "maunium.net/go/mautrix" -) - -const localSessionID = "sfu" - -func InitMatrix(config *Config) { - client, err := mautrix.NewClient(config.HomeserverURL, config.UserID, config.AccessToken) - if err != nil { - logrus.WithError(err).Fatal("Failed to create client") - } - - whoami, err := client.Whoami() - if err != nil { - logrus.WithError(err).Fatal("Failed to identify SFU user") - } - - if config.UserID != whoami.UserID { - logrus.WithField("user_id", config.UserID).Fatal("Access token is for the wrong user") - } - - logrus.WithField("device_id", whoami.DeviceID).Info("Identified SFU as DeviceID") - client.DeviceID = whoami.DeviceID - - focus := NewFocus( - fmt.Sprintf("%s (%s)", config.UserID, client.DeviceID), - client, - &ConferenceConfig{KeepAliveTimeout: config.KeepAliveTimeout}, - ) - - syncer, ok := client.Syncer.(*mautrix.DefaultSyncer) - if !ok { - logrus.Panic("Syncer is not DefaultSyncer") - } - - syncer.ParseEventContent = true - - // TODO: E2EE - syncer.OnEvent(focus.onEvent) - - if err = client.Sync(); err != nil { - logrus.WithError(err).Panic("Sync failed") - } -} diff --git a/src/metadata.go b/src/metadata.go deleted file mode 100644 index 3ae0ed6..0000000 --- a/src/metadata.go +++ /dev/null @@ -1,104 +0,0 @@ -/* -Copyright 2022 The Matrix.org Foundation C.I.C. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package main - -import ( - "sync" - - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" -) - -type Metadata struct { - mutex sync.RWMutex - data event.CallSDPStreamMetadata - conference *Conference -} - -func NewMetadata(conference *Conference) *Metadata { - metadata := new(Metadata) - - metadata.data = make(event.CallSDPStreamMetadata) - metadata.conference = conference - - return metadata -} - -func (m *Metadata) Update(deviceID id.DeviceID, metadata event.CallSDPStreamMetadata) { - m.mutex.Lock() - defer m.mutex.Unlock() - - // Update existing and add new - for streamID, info := range metadata { - m.data[streamID] = info - } - // Remove removed - for streamID, info := range m.data { - _, exists := metadata[streamID] - if info.DeviceID == deviceID && !exists { - delete(m.data, streamID) - } - } -} - -func (m *Metadata) RemoveByDevice(deviceID id.DeviceID) { - m.mutex.Lock() - defer m.mutex.Unlock() - - for streamID, info := range m.data { - if info.DeviceID == deviceID { - delete(m.data, streamID) - } - } -} - -// Get metadata to send to deviceID. This will not include the device's own -// metadata and metadata which includes tracks which we have not received yet. -func (m *Metadata) GetForDevice(deviceID id.DeviceID) event.CallSDPStreamMetadata { - metadata := make(event.CallSDPStreamMetadata) - - m.mutex.RLock() - defer m.mutex.RUnlock() - - for _, publisher := range m.conference.GetPublishers() { - if deviceID == publisher.Call.DeviceID { - continue - } - - streamID := publisher.Track.StreamID() - trackID := publisher.Track.ID() - - info, exists := metadata[streamID] - if exists { - info.Tracks[publisher.Track.ID()] = event.CallSDPStreamMetadataTrack{} - metadata[streamID] = info - } else { - metadata[streamID] = event.CallSDPStreamMetadataObject{ - UserID: publisher.Call.UserID, - DeviceID: publisher.Call.DeviceID, - Purpose: m.data[streamID].Purpose, - AudioMuted: m.data[streamID].AudioMuted, - VideoMuted: m.data[streamID].VideoMuted, - Tracks: event.CallSDPStreamMetadataTracks{ - trackID: {}, - }, - } - } - } - - return metadata -} diff --git a/src/publisher.go b/src/publisher.go deleted file mode 100644 index 3058860..0000000 --- a/src/publisher.go +++ /dev/null @@ -1,195 +0,0 @@ -/* -Copyright 2022 The Matrix.org Foundation C.I.C. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package main - -import ( - "errors" - "io" - "sync" - "sync/atomic" - "time" - - "github.com/pion/rtcp" - "github.com/pion/webrtc/v3" - "github.com/sirupsen/logrus" - "maunium.net/go/mautrix/event" -) - -const ( - minimalPLIInterval = time.Millisecond * 500 - bufferSize = 1500 -) - -type Publisher struct { - Track *webrtc.TrackRemote - Call *Call - - mutex sync.RWMutex - logger *logrus.Entry - subscribers []*Subscriber - - lastPLI atomic.Int64 -} - -func NewPublisher( - track *webrtc.TrackRemote, - call *Call, -) *Publisher { - publisher := new(Publisher) - - publisher.Track = track - publisher.Call = call - - publisher.subscribers = []*Subscriber{} - publisher.logger = call.logger.WithFields(logrus.Fields{ - "track_id": track.ID(), - "track_kind": track.Kind(), - "stream_id": track.StreamID(), - }) - - call.mutex.Lock() - call.Publishers = append(call.Publishers, publisher) - call.mutex.Unlock() - - go publisher.WriteToSubscribers() - - publisher.logger.Info("published track") - - return publisher -} - -func (p *Publisher) Subscribe(call *Call) { - subscriber := NewSubscriber(call) - subscriber.Subscribe(p) - p.AddSubscriber(subscriber) -} - -func (p *Publisher) Stop() { - removed := p.Call.RemovePublisher(p) - - if len(p.subscribers) == 0 && !removed { - return - } - - for _, subscriber := range p.subscribers { - subscriber.Unsubscribe() - p.RemoveSubscriber(subscriber) - } - - p.logger.Info("unpublished track") -} - -func (p *Publisher) AddSubscriber(subscriber *Subscriber) { - p.mutex.Lock() - defer p.mutex.Unlock() - p.subscribers = append(p.subscribers, subscriber) -} - -func (p *Publisher) RemoveSubscriber(toDelete *Subscriber) { - newSubscribers := []*Subscriber{} - - p.mutex.Lock() - for _, subscriber := range p.subscribers { - if subscriber != toDelete { - newSubscribers = append(newSubscribers, subscriber) - } - } - - p.subscribers = newSubscribers - p.mutex.Unlock() -} - -func (p *Publisher) Matches(trackDescription event.SFUTrackDescription) bool { - if p.Track.ID() != trackDescription.TrackID { - return false - } - - if p.Track.StreamID() != trackDescription.StreamID { - return false - } - - return true -} - -func (p *Publisher) WriteRTCP(packets []rtcp.Packet) { - packetsToSend := []rtcp.Packet{} - readSSRC := uint32(p.Track.SSRC()) - - for _, packet := range packets { - switch typedPacket := packet.(type) { - // We mung the packets here, so that the SSRCs match what the - // receiver expects: - // The media SSRC is the SSRC of the media about which the packet is - // reporting; therefore, we mung it to be the SSRC of the publishing - // participant's track. Without this, it would be SSRC of the SFU's - // track which isn't right - case *rtcp.PictureLossIndication: - // Since we sometimes spam the sender with PLIs, make sure we don't send - // them way too often - if time.Now().UnixNano()-p.lastPLI.Load() < minimalPLIInterval.Nanoseconds() { - continue - } - - p.lastPLI.Store(time.Now().UnixNano()) - - typedPacket.MediaSSRC = readSSRC - packetsToSend = append(packetsToSend, typedPacket) - case *rtcp.FullIntraRequest: - typedPacket.MediaSSRC = readSSRC - packetsToSend = append(packetsToSend, typedPacket) - } - - packetsToSend = append(packetsToSend, packet) - } - - if len(packetsToSend) != 0 { - if err := p.Call.PeerConnection.WriteRTCP(packetsToSend); err != nil { - if !errors.Is(err, io.ErrClosedPipe) { - p.logger.WithError(err).Warn("failed to write RTCP on track") - } - } - } -} - -func (p *Publisher) WriteToSubscribers() { - buff := make([]byte, bufferSize) - - for { - index, _, err := p.Track.Read(buff) - if err != nil { - if errors.Is(err, io.EOF) { - p.Stop() - return - } - - p.logger.WithError(err).Warn("failed to read track") - } - - for _, subscriber := range p.subscribers { - if _, err = subscriber.Track.Write(buff[:index]); err != nil { - if errors.Is(err, io.ErrClosedPipe) || errors.Is(err, io.EOF) { - subscriber.Unsubscribe() - p.RemoveSubscriber(subscriber) - - return - } - - p.logger.WithError(err).Warn("failed to write to track") - } - } - } -} diff --git a/src/subscriber.go b/src/subscriber.go deleted file mode 100644 index fbe9235..0000000 --- a/src/subscriber.go +++ /dev/null @@ -1,136 +0,0 @@ -/* -Copyright 2022 The Matrix.org Foundation C.I.C. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package main - -import ( - "errors" - "io" - "sync" - - "github.com/pion/webrtc/v3" - "github.com/sirupsen/logrus" -) - -type Subscriber struct { - Track *webrtc.TrackLocalStaticRTP - - mutex sync.RWMutex - logger *logrus.Entry - call *Call - sender *webrtc.RTPSender - publisher *Publisher -} - -func NewSubscriber(call *Call) *Subscriber { - subscriber := new(Subscriber) - - subscriber.call = call - subscriber.logger = call.logger - - call.mutex.Lock() - call.Subscribers = append(call.Subscribers, subscriber) - call.mutex.Unlock() - - return subscriber -} - -func (s *Subscriber) initLoggingWithTrack(track *webrtc.TrackRemote) { - s.mutex.Lock() - defer s.mutex.Unlock() - s.logger = s.call.logger.WithFields(logrus.Fields{ - "track_id": (*track).ID(), - "track_kind": (*track).Kind(), - "stream_id": (*track).StreamID(), - }) -} - -func (s *Subscriber) Subscribe(publisher *Publisher) { - s.initLoggingWithTrack(publisher.Track) - - if s.publisher != nil { - s.logger.Error("cannot subscribe, if we already are") - } - - track, err := webrtc.NewTrackLocalStaticRTP( - publisher.Track.Codec().RTPCodecCapability, - publisher.Track.ID(), - publisher.Track.StreamID(), - ) - if err != nil { - s.logger.WithError(err).Error("failed to create local static RTP track") - } - - sender, err := s.call.PeerConnection.AddTrack(track) - if err != nil { - s.logger.WithError(err).Error("failed to add track to peer connection") - } - - s.mutex.Lock() - s.Track = track - s.sender = sender - s.publisher = publisher - s.mutex.Unlock() - - if s.Track.Kind() == webrtc.RTPCodecTypeVideo { - go s.forwardRTCP() - } - - publisher.AddSubscriber(s) - - s.logger.Info("subscribed") -} - -func (s *Subscriber) Unsubscribe() { - if s.publisher == nil { - return - } - - if s.call.PeerConnection.ConnectionState() != webrtc.PeerConnectionStateClosed { - err := s.call.PeerConnection.RemoveTrack(s.sender) - if err != nil { - s.logger.WithError(err).Error("failed to remove track") - } - } - - s.call.RemoveSubscriber(s) - - s.mutex.Lock() - s.publisher = nil - s.mutex.Unlock() - - s.logger.Info("unsubscribed") -} - -func (s *Subscriber) forwardRTCP() { - for { - // If we unsubscribed, stop forwarding RTCP packets - if s.publisher == nil { - return - } - - packets, _, err := s.sender.ReadRTCP() - if err != nil { - if errors.Is(err, io.ErrClosedPipe) || errors.Is(err, io.EOF) { - return - } - - s.logger.WithError(err).Warn("failed to read RTCP on track") - } - - s.publisher.WriteRTCP(packets) - } -}