diff --git a/peerconnection_media_test.go b/peerconnection_media_test.go index a5375fde428..d6f131fe0f0 100644 --- a/peerconnection_media_test.go +++ b/peerconnection_media_test.go @@ -13,8 +13,10 @@ import ( "errors" "fmt" "io" + "regexp" "strings" "sync" + "sync/atomic" "testing" "time" @@ -24,6 +26,7 @@ import ( "github.com/pion/rtp" "github.com/pion/sdp/v3" "github.com/pion/transport/v2/test" + "github.com/pion/webrtc/v3/internal/util" "github.com/pion/webrtc/v3/pkg/media" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -1361,6 +1364,222 @@ func TestPeerConnection_Simulcast(t *testing.T) { }) } +type simulcastTestTrackLocal struct { + *TrackLocalStaticRTP +} + +// don't use ssrc&payload in bindings to let the test write different stream packets. +func (s *simulcastTestTrackLocal) WriteRTP(pkt *rtp.Packet) error { + packet := getPacketAllocationFromPool() + + defer resetPacketPoolAllocation(packet) + + *packet = *pkt + + s.mu.RLock() + defer s.mu.RUnlock() + + writeErrs := []error{} + + for _, b := range s.bindings { + if _, err := b.writeStream.WriteRTP(&packet.Header, packet.Payload); err != nil { + writeErrs = append(writeErrs, err) + } + } + + return util.FlattenErrs(writeErrs) +} + +func TestPeerConnection_Simulcast_RTX(t *testing.T) { + lim := test.TimeOut(time.Second * 30) + defer lim.Stop() + + report := test.CheckRoutines(t) + defer report() + + rids := []string{"a", "b"} + pcOffer, pcAnswer, err := newPair() + assert.NoError(t, err) + + vp8WriterAStatic, err := NewTrackLocalStaticRTP(RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion2", WithRTPStreamID(rids[0])) + assert.NoError(t, err) + + vp8WriterBStatic, err := NewTrackLocalStaticRTP(RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion2", WithRTPStreamID(rids[1])) + assert.NoError(t, err) + + vp8WriterA, vp8WriterB := &simulcastTestTrackLocal{vp8WriterAStatic}, &simulcastTestTrackLocal{vp8WriterBStatic} + + sender, err := pcOffer.AddTrack(vp8WriterA) + assert.NoError(t, err) + assert.NotNil(t, sender) + + assert.NoError(t, sender.AddEncoding(vp8WriterB)) + + var ridMapLock sync.RWMutex + ridMap := map[string]int{} + + assertRidCorrect := func(t *testing.T) { + ridMapLock.Lock() + defer ridMapLock.Unlock() + + for _, rid := range rids { + assert.Equal(t, ridMap[rid], 1) + } + assert.Equal(t, len(ridMap), 2) + } + + ridsFullfilled := func() bool { + ridMapLock.Lock() + defer ridMapLock.Unlock() + + ridCount := len(ridMap) + return ridCount == 2 + } + + var rtxPacketRead atomic.Int32 + var wg sync.WaitGroup + wg.Add(2) + + pcAnswer.OnTrack(func(trackRemote *TrackRemote, _ *RTPReceiver) { + ridMapLock.Lock() + ridMap[trackRemote.RID()] = ridMap[trackRemote.RID()] + 1 + ridMapLock.Unlock() + + defer wg.Done() + + for { + _, attr, rerr := trackRemote.ReadRTP() + if rerr != nil { + break + } + if pt, ok := attr.Get(AttributeRtxPayloadType).(byte); ok { + if pt == 97 { + rtxPacketRead.Add(1) + } + } + } + }) + + parameters := sender.GetParameters() + assert.Equal(t, "a", parameters.Encodings[0].RID) + assert.Equal(t, "b", parameters.Encodings[1].RID) + + var midID, ridID, rsid uint8 + for _, extension := range parameters.HeaderExtensions { + switch extension.URI { + case sdp.SDESMidURI: + midID = uint8(extension.ID) + case sdp.SDESRTPStreamIDURI: + ridID = uint8(extension.ID) + case sdesRepairRTPStreamIDURI: + rsid = uint8(extension.ID) + } + } + assert.NotZero(t, midID) + assert.NotZero(t, ridID) + assert.NotZero(t, rsid) + + err = signalPairWithModification(pcOffer, pcAnswer, func(sdp string) string { + // Original chrome sdp contains no ssrc info https://pastebin.com/raw/JTjX6zg6 + re := regexp.MustCompile("(?m)[\r\n]+^.*a=ssrc.*$") + res := re.ReplaceAllString(sdp, "") + return res + }) + assert.NoError(t, err) + + // padding only packets should not affect simulcast probe + var sequenceNumber uint16 + for sequenceNumber = 0; sequenceNumber < simulcastProbeCount+10; sequenceNumber++ { + time.Sleep(20 * time.Millisecond) + + for i, track := range []*simulcastTestTrackLocal{vp8WriterA, vp8WriterB} { + pkt := &rtp.Packet{ + Header: rtp.Header{ + Version: 2, + SequenceNumber: sequenceNumber, + PayloadType: 96, + Padding: true, + SSRC: uint32(i), + }, + Payload: []byte{0x00, 0x02}, + } + + assert.NoError(t, track.WriteRTP(pkt)) + } + } + assert.False(t, ridsFullfilled(), "Simulcast probe should not be fulfilled by padding only packets") + + for ; !ridsFullfilled(); sequenceNumber++ { + time.Sleep(20 * time.Millisecond) + + for i, track := range []*simulcastTestTrackLocal{vp8WriterA, vp8WriterB} { + pkt := &rtp.Packet{ + Header: rtp.Header{ + Version: 2, + SequenceNumber: sequenceNumber, + PayloadType: 96, + SSRC: uint32(i), + }, + Payload: []byte{0x00}, + } + assert.NoError(t, pkt.Header.SetExtension(midID, []byte("0"))) + assert.NoError(t, pkt.Header.SetExtension(ridID, []byte(track.RID()))) + + assert.NoError(t, track.WriteRTP(pkt)) + } + } + + assertRidCorrect(t) + + for i := 0; i < simulcastProbeCount+10; i++ { + sequenceNumber++ + time.Sleep(10 * time.Millisecond) + + for j, track := range []*simulcastTestTrackLocal{vp8WriterA, vp8WriterB} { + pkt := &rtp.Packet{ + Header: rtp.Header{ + Version: 2, + SequenceNumber: sequenceNumber, + PayloadType: 97, + SSRC: uint32(100 + j), + }, + Payload: []byte{0x00, 0x00, 0x00, 0x00, 0x00}, + } + assert.NoError(t, pkt.Header.SetExtension(midID, []byte("0"))) + assert.NoError(t, pkt.Header.SetExtension(ridID, []byte(track.RID()))) + assert.NoError(t, pkt.Header.SetExtension(rsid, []byte(track.RID()))) + + assert.NoError(t, track.WriteRTP(pkt)) + } + } + + for ; rtxPacketRead.Load() == 0; sequenceNumber++ { + time.Sleep(20 * time.Millisecond) + + for i, track := range []*simulcastTestTrackLocal{vp8WriterA, vp8WriterB} { + pkt := &rtp.Packet{ + Header: rtp.Header{ + Version: 2, + SequenceNumber: sequenceNumber, + PayloadType: 96, + SSRC: uint32(i), + }, + Payload: []byte{0x00}, + } + assert.NoError(t, pkt.Header.SetExtension(midID, []byte("0"))) + assert.NoError(t, pkt.Header.SetExtension(ridID, []byte(track.RID()))) + + assert.NoError(t, track.WriteRTP(pkt)) + } + } + + closePairNow(t, pcOffer, pcAnswer) + + wg.Wait() + + assert.Greater(t, rtxPacketRead.Load(), int32(0), "no rtx packet read") +} + // Everytime we receieve a new SSRC we probe it and try to determine the proper way to handle it. // In most cases a Track explicitly declares a SSRC and a OnTrack is fired. In two cases we don't // know the SSRC ahead of time diff --git a/rtpreceiver.go b/rtpreceiver.go index 8481a09f27f..ec7c79bce2e 100644 --- a/rtpreceiver.go +++ b/rtpreceiver.go @@ -409,6 +409,10 @@ func (r *RTPReceiver) receiveForRtx(ssrc SSRC, rsid string, streamInfo *intercep for i := range r.tracks { if r.tracks[i].track.RID() == rsid { track = &r.tracks[i] + if track.track.RtxSSRC() == 0 { + track.track.setRtxSSRC(SSRC(streamInfo.SSRC)) + } + break } } } diff --git a/track_remote.go b/track_remote.go index dfdb12b9594..e2d8d70df2e 100644 --- a/track_remote.go +++ b/track_remote.go @@ -223,3 +223,9 @@ func (t *TrackRemote) HasRTX() bool { defer t.mu.RUnlock() return t.rtxSsrc != 0 } + +func (t *TrackRemote) setRtxSSRC(ssrc SSRC) { + t.mu.Lock() + defer t.mu.Unlock() + t.rtxSsrc = ssrc +}