Skip to content

Commit

Permalink
Fix pli missed cause by two goroutine compete rtcp reader
Browse files Browse the repository at this point in the history
There were two goroutine to read rtcp when publishing a
LocalSampleTrack, one is publication to calculate rtt
and the other is LocalSampleTrack itself, cause pli hander
rely on LocalSampleTrack's rtcp callback might miss pli
request.
  • Loading branch information
cnderrauber committed Jan 8, 2024
1 parent 12e24f8 commit c7a4620
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 39 deletions.
5 changes: 1 addition & 4 deletions engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,10 +173,6 @@ func (e *RTCEngine) TrackPublishedChan() <-chan *livekit.TrackPublishedResponse
}

func (e *RTCEngine) setRTT(rtt uint32) {
if pc := e.publisher; pc != nil {
pc.SetRTT(rtt)
}

if pc := e.subscriber; pc != nil {
pc.SetRTT(rtt)
}
Expand All @@ -193,6 +189,7 @@ func (e *RTCEngine) configure(res *livekit.JoinResponse) error {
Configuration: configuration,
RetransmitBufferSize: e.connParams.RetransmitBufferSize,
Pacer: e.connParams.Pacer,
OnRTTUpdate: e.setRTT,
}); err != nil {
return err
}
Expand Down
9 changes: 4 additions & 5 deletions localparticipant.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,6 @@ func (p *LocalParticipant) PublishTrack(track webrtc.TrackLocal, opts *TrackPubl
}

pub := NewLocalTrackPublication(kind, track, *opts, p.engine.client)
pub.OnRttUpdate(func(rtt uint32) {
p.engine.setRTT(rtt)
})
pub.onMuteChanged = p.onTrackMuted

req := &livekit.AddTrackRequest{
Expand Down Expand Up @@ -107,7 +104,9 @@ func (p *LocalParticipant) PublishTrack(track webrtc.TrackLocal, opts *TrackPubl
return nil, err
}

pub.setSender(transceiver.Sender())
// LocalSampleTrack will consume rtcp packets so we don't need to consume again
_, isSampleTrack := track.(*LocalSampleTrack)
pub.setSender(transceiver.Sender(), !isSampleTrack)

pub.updateInfo(pubRes.Track)
p.addPublication(pub)
Expand Down Expand Up @@ -196,7 +195,7 @@ func (p *LocalParticipant) PublishSimulcastTrack(tracks []*LocalSampleTrack, opt
return nil, err
}
sender = transceiver.Sender()
pub.setSender(sender)
pub.setSender(sender, false)
} else {
if err = sender.AddEncoding(st); err != nil {
return nil, err
Expand Down
1 change: 1 addition & 0 deletions localsampletrack.go
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,7 @@ func (s *LocalSampleTrack) rtcpWorker(rtcpReader interceptor.RTCPReader) {

pkts, err := rtcp.Unmarshal(b[:i])
if err != nil {
logger.Warnw("could not unmarshal rtcp", err)
return
}
for _, packet := range pkts {
Expand Down
66 changes: 66 additions & 0 deletions pkg/interceptor/rttinteceptor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package interceptor

import (
"github.com/livekit/mediatransportutil"
"github.com/pion/interceptor"
"github.com/pion/rtcp"
)

type RTTInterceptorFactory struct {
onRttUpdate func(rtt uint32)
}

func NewRTTInterceptorFactory(onRttUpdate func(rtt uint32)) *RTTInterceptorFactory {
return &RTTInterceptorFactory{
onRttUpdate: onRttUpdate,
}
}

func (r *RTTInterceptorFactory) NewInterceptor(_ string) (interceptor.Interceptor, error) {
return NewRTTInterceptor(r.onRttUpdate), nil
}

type RTTInterceptor struct {
interceptor.NoOp

onRttUpdate func(rtt uint32)
}

func NewRTTInterceptor(onRttUpdate func(rtt uint32)) *RTTInterceptor {
return &RTTInterceptor{
onRttUpdate: onRttUpdate,
}
}

func (r *RTTInterceptor) BindRTCPReader(reader interceptor.RTCPReader) interceptor.RTCPReader {
return interceptor.RTCPReaderFunc(func(b []byte, a interceptor.Attributes) (int, interceptor.Attributes, error) {
i, attr, err := reader.Read(b, a)
if err != nil {
return 0, nil, err
}

if attr == nil {
attr = make(interceptor.Attributes)
}
pkts, err := attr.GetRTCPPackets(b[:i])
if err != nil {
return 0, nil, err
}

rttCaculate:
for _, packet := range pkts {
if rr, ok := packet.(*rtcp.ReceiverReport); ok {
for _, report := range rr.Reports {
rtt, err := mediatransportutil.GetRttMsFromReceiverReportOnly(&report)
if err == nil && rtt != 0 {
r.onRttUpdate(rtt)
}

break rttCaculate
}
}
}

return i, attr, err
})
}
32 changes: 7 additions & 25 deletions publication.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import (
"go.uber.org/atomic"
"google.golang.org/protobuf/proto"

"github.com/livekit/mediatransportutil"
"github.com/livekit/protocol/livekit"
)

Expand Down Expand Up @@ -240,7 +239,6 @@ type LocalTrackPublication struct {
sender *webrtc.RTPSender
// set for simulcasted tracks
simulcastTracks map[livekit.VideoQuality]*LocalSampleTrack
onRttUpdate func(uint32)
opts TrackPublicationOptions
onMuteChanged func(*LocalTrackPublication, bool)
}
Expand Down Expand Up @@ -315,43 +313,27 @@ func (p *LocalTrackPublication) addSimulcastTrack(st *LocalSampleTrack) {
}
}

func (p *LocalTrackPublication) setSender(sender *webrtc.RTPSender) {
func (p *LocalTrackPublication) setSender(sender *webrtc.RTPSender, consumeRTCP bool) {
p.lock.Lock()
p.sender = sender
p.lock.Unlock()

if !consumeRTCP {
return
}

// consume RTCP packets so interceptors can handle them (rtt, nacks...)
go func() {
for {
packets, _, err := sender.ReadRTCP()
_, _, err := sender.ReadRTCP()
if err != nil {
// pipe closed
return
}

rttCaculate:
for _, packet := range packets {
if rr, ok := packet.(*rtcp.ReceiverReport); ok {
for _, r := range rr.Reports {
rr.Reports = append(rr.Reports, r)
rtt, err := mediatransportutil.GetRttMsFromReceiverReportOnly(&r)
if err == nil && rtt != 0 && p.onRttUpdate != nil {
p.onRttUpdate(rtt)
}

break rttCaculate
}
}
}
}
}()
}

func (p *LocalTrackPublication) OnRttUpdate(cb func(uint32)) {
p.lock.Lock()
p.onRttUpdate = cb
p.lock.Unlock()
}

func (p *LocalTrackPublication) CloseTrack() {
for _, st := range p.simulcastTracks {
st.Close()
Expand Down
26 changes: 21 additions & 5 deletions transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ type PCTransport struct {
nackGenerator *sdkinterceptor.NackGeneratorInterceptorFactory

onRemoteDescriptionSettled func() error
onRTTUpdate func(rtt uint32)

OnOffer func(description webrtc.SessionDescription)
}
Expand All @@ -64,6 +65,7 @@ type PCTransportParams struct {

RetransmitBufferSize uint16
Pacer pacer.Factory
OnRTTUpdate func(rtt uint32)
}

func NewPCTransport(params PCTransportParams) (*PCTransport, error) {
Expand All @@ -86,6 +88,11 @@ func NewPCTransport(params PCTransportParams) (*PCTransport, error) {

i := &interceptor.Registry{}

t := &PCTransport{
debouncedNegotiate: debounce.New(negotiationFrequency),
onRTTUpdate: params.OnRTTUpdate,
}

// nack interceptor
generator := &sdkinterceptor.NackGeneratorInterceptorFactory{}
var generatorOption []nack.ResponderOption
Expand Down Expand Up @@ -119,6 +126,10 @@ func NewPCTransport(params PCTransportParams) (*PCTransport, error) {

i.Add(sdkinterceptor.NewLimitSizeInterceptorFactory())

if params.OnRTTUpdate != nil {
i.Add(sdkinterceptor.NewRTTInterceptorFactory(t.handleRTTUpdate))
}

se := webrtc.SettingEngine{}
se.SetSRTPProtectionProfiles(dtls.SRTP_AEAD_AES_128_GCM, dtls.SRTP_AES128_CM_HMAC_SHA1_80)
se.SetDTLSRetransmissionInterval(dtlsRetransmissionInterval)
Expand All @@ -130,17 +141,22 @@ func NewPCTransport(params PCTransportParams) (*PCTransport, error) {
return nil, err
}

t := &PCTransport{
pc: pc,
debouncedNegotiate: debounce.New(negotiationFrequency),
nackGenerator: generator,
}
t.pc = pc
t.nackGenerator = generator

pc.OnICEGatheringStateChange(t.onICEGatheringStateChange)

return t, nil
}

func (t *PCTransport) handleRTTUpdate(rtt uint32) {
t.SetRTT(rtt)

if t.onRTTUpdate != nil {
t.onRTTUpdate(rtt)
}
}

func (t *PCTransport) onICEGatheringStateChange(state webrtc.ICEGathererState) {
if state != webrtc.ICEGathererStateComplete {
return
Expand Down

0 comments on commit c7a4620

Please sign in to comment.