Skip to content

Commit

Permalink
Update RtxSSRC for simulcast track remote
Browse files Browse the repository at this point in the history
Fix pion#2751, updates remote track's rtx ssrc for
simulcast track doesn't contain rtx ssrc in sdp
since readRTX relies on rtx ssrc to determine if
it has a rtx stream.
  • Loading branch information
cnderrauber committed Apr 26, 2024
1 parent 7be0482 commit e7cf3ba
Show file tree
Hide file tree
Showing 3 changed files with 229 additions and 0 deletions.
219 changes: 219 additions & 0 deletions peerconnection_media_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@ import (
"errors"
"fmt"
"io"
"regexp"
"strings"
"sync"
"sync/atomic"
"testing"
"time"

Expand All @@ -24,6 +26,7 @@ import (
"github.com/pion/rtp"
"github.com/pion/sdp/v3"
"github.com/pion/transport/v2/test"
"github.com/pion/webrtc/v3/internal/util"
"github.com/pion/webrtc/v3/pkg/media"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -1361,6 +1364,222 @@ func TestPeerConnection_Simulcast(t *testing.T) {
})
}

type simulcastTestTrackLocal struct {
*TrackLocalStaticRTP
}

// don't use ssrc&payload in bindings to let the test write different stream packets.
func (s *simulcastTestTrackLocal) WriteRTP(pkt *rtp.Packet) error {
packet := getPacketAllocationFromPool()

defer resetPacketPoolAllocation(packet)

*packet = *pkt

s.mu.RLock()
defer s.mu.RUnlock()

writeErrs := []error{}

for _, b := range s.bindings {
if _, err := b.writeStream.WriteRTP(&packet.Header, packet.Payload); err != nil {
writeErrs = append(writeErrs, err)
}
}

return util.FlattenErrs(writeErrs)
}

func TestPeerConnection_Simulcast_RTX(t *testing.T) {
lim := test.TimeOut(time.Second * 30)
defer lim.Stop()

report := test.CheckRoutines(t)
defer report()

rids := []string{"a", "b"}
pcOffer, pcAnswer, err := newPair()
assert.NoError(t, err)

vp8WriterAStatic, err := NewTrackLocalStaticRTP(RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion2", WithRTPStreamID(rids[0]))
assert.NoError(t, err)

vp8WriterBStatic, err := NewTrackLocalStaticRTP(RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion2", WithRTPStreamID(rids[1]))
assert.NoError(t, err)

vp8WriterA, vp8WriterB := &simulcastTestTrackLocal{vp8WriterAStatic}, &simulcastTestTrackLocal{vp8WriterBStatic}

sender, err := pcOffer.AddTrack(vp8WriterA)
assert.NoError(t, err)
assert.NotNil(t, sender)

assert.NoError(t, sender.AddEncoding(vp8WriterB))

var ridMapLock sync.RWMutex
ridMap := map[string]int{}

assertRidCorrect := func(t *testing.T) {
ridMapLock.Lock()
defer ridMapLock.Unlock()

for _, rid := range rids {
assert.Equal(t, ridMap[rid], 1)
}
assert.Equal(t, len(ridMap), 2)
}

ridsFullfilled := func() bool {
ridMapLock.Lock()
defer ridMapLock.Unlock()

ridCount := len(ridMap)
return ridCount == 2
}

var rtxPacketRead atomic.Int32
var wg sync.WaitGroup
wg.Add(2)

pcAnswer.OnTrack(func(trackRemote *TrackRemote, _ *RTPReceiver) {
ridMapLock.Lock()
ridMap[trackRemote.RID()] = ridMap[trackRemote.RID()] + 1
ridMapLock.Unlock()

defer wg.Done()

for {
_, attr, rerr := trackRemote.ReadRTP()
if rerr != nil {
break
}
if pt, ok := attr.Get(AttributeRtxPayloadType).(byte); ok {
if pt == 97 {
rtxPacketRead.Add(1)
}
}
}
})

parameters := sender.GetParameters()
assert.Equal(t, "a", parameters.Encodings[0].RID)
assert.Equal(t, "b", parameters.Encodings[1].RID)

var midID, ridID, rsid uint8
for _, extension := range parameters.HeaderExtensions {
switch extension.URI {
case sdp.SDESMidURI:
midID = uint8(extension.ID)
case sdp.SDESRTPStreamIDURI:
ridID = uint8(extension.ID)
case sdesRepairRTPStreamIDURI:
rsid = uint8(extension.ID)
}
}
assert.NotZero(t, midID)
assert.NotZero(t, ridID)
assert.NotZero(t, rsid)

err = signalPairWithModification(pcOffer, pcAnswer, func(sdp string) string {
// Original chrome sdp contains no ssrc info https://pastebin.com/raw/JTjX6zg6
re := regexp.MustCompile("(?m)[\r\n]+^.*a=ssrc.*$")
res := re.ReplaceAllString(sdp, "")
return res
})
assert.NoError(t, err)

// padding only packets should not affect simulcast probe
var sequenceNumber uint16
for sequenceNumber = 0; sequenceNumber < simulcastProbeCount+10; sequenceNumber++ {
time.Sleep(20 * time.Millisecond)

for i, track := range []*simulcastTestTrackLocal{vp8WriterA, vp8WriterB} {
pkt := &rtp.Packet{
Header: rtp.Header{
Version: 2,
SequenceNumber: sequenceNumber,
PayloadType: 96,
Padding: true,
SSRC: uint32(i),
},
Payload: []byte{0x00, 0x02},
}

assert.NoError(t, track.WriteRTP(pkt))
}
}
assert.False(t, ridsFullfilled(), "Simulcast probe should not be fulfilled by padding only packets")

for ; !ridsFullfilled(); sequenceNumber++ {
time.Sleep(20 * time.Millisecond)

for i, track := range []*simulcastTestTrackLocal{vp8WriterA, vp8WriterB} {
pkt := &rtp.Packet{
Header: rtp.Header{
Version: 2,
SequenceNumber: sequenceNumber,
PayloadType: 96,
SSRC: uint32(i),
},
Payload: []byte{0x00},
}
assert.NoError(t, pkt.Header.SetExtension(midID, []byte("0")))
assert.NoError(t, pkt.Header.SetExtension(ridID, []byte(track.RID())))

assert.NoError(t, track.WriteRTP(pkt))
}
}

assertRidCorrect(t)

for i := 0; i < simulcastProbeCount+10; i++ {
sequenceNumber++
time.Sleep(10 * time.Millisecond)

for j, track := range []*simulcastTestTrackLocal{vp8WriterA, vp8WriterB} {
pkt := &rtp.Packet{
Header: rtp.Header{
Version: 2,
SequenceNumber: sequenceNumber,
PayloadType: 97,
SSRC: uint32(100 + j),
},
Payload: []byte{0x00, 0x00, 0x00, 0x00, 0x00},
}
assert.NoError(t, pkt.Header.SetExtension(midID, []byte("0")))
assert.NoError(t, pkt.Header.SetExtension(ridID, []byte(track.RID())))
assert.NoError(t, pkt.Header.SetExtension(rsid, []byte(track.RID())))

assert.NoError(t, track.WriteRTP(pkt))
}
}

for ; rtxPacketRead.Load() == 0; sequenceNumber++ {
time.Sleep(20 * time.Millisecond)

for i, track := range []*simulcastTestTrackLocal{vp8WriterA, vp8WriterB} {
pkt := &rtp.Packet{
Header: rtp.Header{
Version: 2,
SequenceNumber: sequenceNumber,
PayloadType: 96,
SSRC: uint32(i),
},
Payload: []byte{0x00},
}
assert.NoError(t, pkt.Header.SetExtension(midID, []byte("0")))
assert.NoError(t, pkt.Header.SetExtension(ridID, []byte(track.RID())))

assert.NoError(t, track.WriteRTP(pkt))
}
}

closePairNow(t, pcOffer, pcAnswer)

wg.Wait()

assert.Greater(t, rtxPacketRead.Load(), int32(0), "no rtx packet read")
}

// Everytime we receieve a new SSRC we probe it and try to determine the proper way to handle it.
// In most cases a Track explicitly declares a SSRC and a OnTrack is fired. In two cases we don't
// know the SSRC ahead of time
Expand Down
4 changes: 4 additions & 0 deletions rtpreceiver.go
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,10 @@ func (r *RTPReceiver) receiveForRtx(ssrc SSRC, rsid string, streamInfo *intercep
for i := range r.tracks {
if r.tracks[i].track.RID() == rsid {
track = &r.tracks[i]
if track.track.RtxSSRC() == 0 {
track.track.setRtxSSRC(SSRC(streamInfo.SSRC))
}
break
}
}
}
Expand Down
6 changes: 6 additions & 0 deletions track_remote.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,3 +223,9 @@ func (t *TrackRemote) HasRTX() bool {
defer t.mu.RUnlock()
return t.rtxSsrc != 0
}

func (t *TrackRemote) setRtxSSRC(ssrc SSRC) {
t.mu.Lock()
defer t.mu.Unlock()
t.rtxSsrc = ssrc
}

0 comments on commit e7cf3ba

Please sign in to comment.