diff --git a/engine.go b/engine.go index 7a2dd697..b41105e0 100644 --- a/engine.go +++ b/engine.go @@ -173,10 +173,6 @@ func (e *RTCEngine) TrackPublishedChan() <-chan *livekit.TrackPublishedResponse } func (e *RTCEngine) setRTT(rtt uint32) { - if pc := e.publisher; pc != nil { - pc.SetRTT(rtt) - } - if pc := e.subscriber; pc != nil { pc.SetRTT(rtt) } @@ -193,6 +189,7 @@ func (e *RTCEngine) configure(res *livekit.JoinResponse) error { Configuration: configuration, RetransmitBufferSize: e.connParams.RetransmitBufferSize, Pacer: e.connParams.Pacer, + OnRTTUpdate: e.setRTT, }); err != nil { return err } diff --git a/localparticipant.go b/localparticipant.go index 937ba257..1c29ff2e 100644 --- a/localparticipant.go +++ b/localparticipant.go @@ -55,9 +55,6 @@ func (p *LocalParticipant) PublishTrack(track webrtc.TrackLocal, opts *TrackPubl } pub := NewLocalTrackPublication(kind, track, *opts, p.engine.client) - pub.OnRttUpdate(func(rtt uint32) { - p.engine.setRTT(rtt) - }) pub.onMuteChanged = p.onTrackMuted req := &livekit.AddTrackRequest{ @@ -107,7 +104,9 @@ func (p *LocalParticipant) PublishTrack(track webrtc.TrackLocal, opts *TrackPubl return nil, err } - pub.setSender(transceiver.Sender()) + // LocalSampleTrack will consume rtcp packets so we don't need to consume again + _, isSampleTrack := track.(*LocalSampleTrack) + pub.setSender(transceiver.Sender(), !isSampleTrack) pub.updateInfo(pubRes.Track) p.addPublication(pub) @@ -196,7 +195,7 @@ func (p *LocalParticipant) PublishSimulcastTrack(tracks []*LocalSampleTrack, opt return nil, err } sender = transceiver.Sender() - pub.setSender(sender) + pub.setSender(sender, false) } else { if err = sender.AddEncoding(st); err != nil { return nil, err diff --git a/localsampletrack.go b/localsampletrack.go index f75933a7..8c9fb0b6 100644 --- a/localsampletrack.go +++ b/localsampletrack.go @@ -387,6 +387,7 @@ func (s *LocalSampleTrack) rtcpWorker(rtcpReader interceptor.RTCPReader) { pkts, err := rtcp.Unmarshal(b[:i]) if err != nil { + logger.Warnw("could not unmarshal rtcp", err) return } for _, packet := range pkts { diff --git a/pkg/interceptor/rttinteceptor.go b/pkg/interceptor/rttinteceptor.go new file mode 100644 index 00000000..0b27a3db --- /dev/null +++ b/pkg/interceptor/rttinteceptor.go @@ -0,0 +1,66 @@ +package interceptor + +import ( + "github.com/livekit/mediatransportutil" + "github.com/pion/interceptor" + "github.com/pion/rtcp" +) + +type RTTInterceptorFactory struct { + onRttUpdate func(rtt uint32) +} + +func NewRTTInterceptorFactory(onRttUpdate func(rtt uint32)) *RTTInterceptorFactory { + return &RTTInterceptorFactory{ + onRttUpdate: onRttUpdate, + } +} + +func (r *RTTInterceptorFactory) NewInterceptor(_ string) (interceptor.Interceptor, error) { + return NewRTTInterceptor(r.onRttUpdate), nil +} + +type RTTInterceptor struct { + interceptor.NoOp + + onRttUpdate func(rtt uint32) +} + +func NewRTTInterceptor(onRttUpdate func(rtt uint32)) *RTTInterceptor { + return &RTTInterceptor{ + onRttUpdate: onRttUpdate, + } +} + +func (r *RTTInterceptor) BindRTCPReader(reader interceptor.RTCPReader) interceptor.RTCPReader { + return interceptor.RTCPReaderFunc(func(b []byte, a interceptor.Attributes) (int, interceptor.Attributes, error) { + i, attr, err := reader.Read(b, a) + if err != nil { + return 0, nil, err + } + + if attr == nil { + attr = make(interceptor.Attributes) + } + pkts, err := attr.GetRTCPPackets(b[:i]) + if err != nil { + return 0, nil, err + } + + rttCaculate: + for _, packet := range pkts { + if rr, ok := packet.(*rtcp.ReceiverReport); ok { + for _, report := range rr.Reports { + rtt, err := mediatransportutil.GetRttMsFromReceiverReportOnly(&report) + if err == nil && rtt != 0 { + r.onRttUpdate(rtt) + } + + break rttCaculate + } + } + } + + return i, attr, err + }) +} diff --git a/publication.go b/publication.go index 36265747..fbde88f9 100644 --- a/publication.go +++ b/publication.go @@ -22,7 +22,6 @@ import ( "go.uber.org/atomic" "google.golang.org/protobuf/proto" - "github.com/livekit/mediatransportutil" "github.com/livekit/protocol/livekit" ) @@ -240,7 +239,6 @@ type LocalTrackPublication struct { sender *webrtc.RTPSender // set for simulcasted tracks simulcastTracks map[livekit.VideoQuality]*LocalSampleTrack - onRttUpdate func(uint32) opts TrackPublicationOptions onMuteChanged func(*LocalTrackPublication, bool) } @@ -315,43 +313,27 @@ func (p *LocalTrackPublication) addSimulcastTrack(st *LocalSampleTrack) { } } -func (p *LocalTrackPublication) setSender(sender *webrtc.RTPSender) { +func (p *LocalTrackPublication) setSender(sender *webrtc.RTPSender, consumeRTCP bool) { p.lock.Lock() p.sender = sender p.lock.Unlock() + if !consumeRTCP { + return + } + + // consume RTCP packets so interceptors can handle them (rtt, nacks...) go func() { for { - packets, _, err := sender.ReadRTCP() + _, _, err := sender.ReadRTCP() if err != nil { // pipe closed return } - - rttCaculate: - for _, packet := range packets { - if rr, ok := packet.(*rtcp.ReceiverReport); ok { - for _, r := range rr.Reports { - rr.Reports = append(rr.Reports, r) - rtt, err := mediatransportutil.GetRttMsFromReceiverReportOnly(&r) - if err == nil && rtt != 0 && p.onRttUpdate != nil { - p.onRttUpdate(rtt) - } - - break rttCaculate - } - } - } } }() } -func (p *LocalTrackPublication) OnRttUpdate(cb func(uint32)) { - p.lock.Lock() - p.onRttUpdate = cb - p.lock.Unlock() -} - func (p *LocalTrackPublication) CloseTrack() { for _, st := range p.simulcastTracks { st.Close() diff --git a/transport.go b/transport.go index 3bf26fca..82027d4d 100644 --- a/transport.go +++ b/transport.go @@ -55,6 +55,7 @@ type PCTransport struct { nackGenerator *sdkinterceptor.NackGeneratorInterceptorFactory onRemoteDescriptionSettled func() error + onRTTUpdate func(rtt uint32) OnOffer func(description webrtc.SessionDescription) } @@ -64,6 +65,7 @@ type PCTransportParams struct { RetransmitBufferSize uint16 Pacer pacer.Factory + OnRTTUpdate func(rtt uint32) } func NewPCTransport(params PCTransportParams) (*PCTransport, error) { @@ -86,6 +88,11 @@ func NewPCTransport(params PCTransportParams) (*PCTransport, error) { i := &interceptor.Registry{} + t := &PCTransport{ + debouncedNegotiate: debounce.New(negotiationFrequency), + onRTTUpdate: params.OnRTTUpdate, + } + // nack interceptor generator := &sdkinterceptor.NackGeneratorInterceptorFactory{} var generatorOption []nack.ResponderOption @@ -119,6 +126,10 @@ func NewPCTransport(params PCTransportParams) (*PCTransport, error) { i.Add(sdkinterceptor.NewLimitSizeInterceptorFactory()) + if params.OnRTTUpdate != nil { + i.Add(sdkinterceptor.NewRTTInterceptorFactory(t.handleRTTUpdate)) + } + se := webrtc.SettingEngine{} se.SetSRTPProtectionProfiles(dtls.SRTP_AEAD_AES_128_GCM, dtls.SRTP_AES128_CM_HMAC_SHA1_80) se.SetDTLSRetransmissionInterval(dtlsRetransmissionInterval) @@ -130,17 +141,22 @@ func NewPCTransport(params PCTransportParams) (*PCTransport, error) { return nil, err } - t := &PCTransport{ - pc: pc, - debouncedNegotiate: debounce.New(negotiationFrequency), - nackGenerator: generator, - } + t.pc = pc + t.nackGenerator = generator pc.OnICEGatheringStateChange(t.onICEGatheringStateChange) return t, nil } +func (t *PCTransport) handleRTTUpdate(rtt uint32) { + t.SetRTT(rtt) + + if t.onRTTUpdate != nil { + t.onRTTUpdate(rtt) + } +} + func (t *PCTransport) onICEGatheringStateChange(state webrtc.ICEGathererState) { if state != webrtc.ICEGathererStateComplete { return