diff --git a/interceptor_test.go b/interceptor_test.go index aad610bb35d..2670627df74 100644 --- a/interceptor_test.go +++ b/interceptor_test.go @@ -225,12 +225,59 @@ func Test_InterceptorRegistry_Build(t *testing.T) { }, }) - peerConnectionA, err := NewAPI(WithInterceptorRegistry(ir)).NewPeerConnection(Configuration{}) - assert.NoError(t, err) - - peerConnectionB, err := NewAPI(WithInterceptorRegistry(ir)).NewPeerConnection(Configuration{}) + peerConnectionA, peerConnectionB, err := NewAPI(WithInterceptorRegistry(ir)).newPair(Configuration{}) assert.NoError(t, err) assert.Equal(t, 2, registryBuildCount) closePairNow(t, peerConnectionA, peerConnectionB) } + +func Test_Interceptor_ZeroSSRC(t *testing.T) { + to := test.TimeOut(time.Second * 20) + defer to.Stop() + + report := test.CheckRoutines(t) + defer report() + + track, err := NewTrackLocalStaticRTP(RTPCodecCapability{MimeType: MimeTypeVP8}, "video", "pion") + assert.NoError(t, err) + + offerer, answerer, err := newPair() + assert.NoError(t, err) + + _, err = offerer.AddTrack(track) + assert.NoError(t, err) + + probeReceiverCreated := make(chan struct{}) + + go func() { + sequenceNumber := uint16(0) + for range time.NewTicker(time.Millisecond * 20).C { + track.mu.Lock() + if len(track.bindings) == 1 { + _, err = track.bindings[0].writeStream.WriteRTP(&rtp.Header{ + Version: 2, + SSRC: 0, + SequenceNumber: sequenceNumber, + }, []byte{0, 1, 2, 3, 4, 5}) + assert.NoError(t, err) + } + sequenceNumber++ + track.mu.Unlock() + + if nonMediaBandwidthProbe, ok := answerer.nonMediaBandwidthProbe.Load().(*RTPReceiver); ok { + assert.Equal(t, len(nonMediaBandwidthProbe.Tracks()), 1) + close(probeReceiverCreated) + return + } + } + }() + + assert.NoError(t, signalPair(offerer, answerer)) + + peerConnectionConnected := untilConnectionState(PeerConnectionStateConnected, offerer, answerer) + peerConnectionConnected.Wait() + + <-probeReceiverCreated + closePairNow(t, offerer, answerer) +} diff --git a/peerconnection.go b/peerconnection.go index 67af02bc639..0f7e78d8796 100644 --- a/peerconnection.go +++ b/peerconnection.go @@ -68,7 +68,8 @@ type PeerConnection struct { // should be defined (see JSEP 3.4.1). greaterMid int - rtpTransceivers []*RTPTransceiver + rtpTransceivers []*RTPTransceiver + nonMediaBandwidthProbe atomic.Value // RTPReceiver onSignalingStateChangeHandler func(SignalingState) onICEConnectionStateChangeHandler atomic.Value // func(ICEConnectionState) @@ -1524,6 +1525,40 @@ func (pc *PeerConnection) handleUndeclaredSSRC(ssrc SSRC, remoteDescription *Ses return true, nil } +// Chrome sends probing traffic on SSRC 0. This reads the packets to ensure that we properly +// generate TWCC reports for it. Since this isn't actually media we don't pass this to the user +func (pc *PeerConnection) handleNonMediaBandwidthProbe() { + nonMediaBandwidthProbe, err := pc.api.NewRTPReceiver(RTPCodecTypeVideo, pc.dtlsTransport) + if err != nil { + pc.log.Errorf("handleNonMediaBandwidthProbe failed to create RTPReceiver: %v", err) + return + } + + if err = nonMediaBandwidthProbe.Receive(RTPReceiveParameters{ + Encodings: []RTPDecodingParameters{ + { + RTPCodingParameters: RTPCodingParameters{ + SSRC: 0, + PayloadType: 97, + }, + }, + }, + }); err != nil { + pc.log.Errorf("handleNonMediaBandwidthProbe failed to start RTPReceiver: %v", err) + return + } + + pc.nonMediaBandwidthProbe.Store(nonMediaBandwidthProbe) + b := make([]byte, pc.api.settingEngine.getReceiveMTU()) + for { + _, _, err = nonMediaBandwidthProbe.readRTP(b, nonMediaBandwidthProbe.Track()) + if err != nil { + pc.log.Tracef("handleNonMediaBandwidthProbe read exiting: %v", err) + return + } + } +} + func (pc *PeerConnection) handleIncomingSSRC(rtpStream io.Reader, ssrc SSRC) error { //nolint:gocognit remoteDescription := pc.RemoteDescription() if remoteDescription == nil { @@ -1656,6 +1691,11 @@ func (pc *PeerConnection) undeclaredRTPMediaProcessor() { continue } + if ssrc == 0 { + go pc.handleNonMediaBandwidthProbe() + continue + } + pc.dtlsTransport.storeSimulcastStream(stream) if atomic.AddUint64(&simulcastRoutineCount, 1) >= simulcastMaxProbeRoutines { @@ -2072,6 +2112,9 @@ func (pc *PeerConnection) Close() error { closeErrs = append(closeErrs, t.Stop()) } } + if nonMediaBandwidthProbe, ok := pc.nonMediaBandwidthProbe.Load().(*RTPReceiver); ok { + closeErrs = append(closeErrs, nonMediaBandwidthProbe.Stop()) + } pc.mu.Unlock() // https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #5) diff --git a/peerconnection_media_test.go b/peerconnection_media_test.go index 0e29d90f400..121b41c68a1 100644 --- a/peerconnection_media_test.go +++ b/peerconnection_media_test.go @@ -1528,7 +1528,7 @@ func TestPeerConnection_Simulcast_RTX(t *testing.T) { SequenceNumber: sequenceNumber, PayloadType: 96, Padding: true, - SSRC: uint32(i), + SSRC: uint32(i + 1), }, Payload: []byte{0x00, 0x02}, } @@ -1547,7 +1547,7 @@ func TestPeerConnection_Simulcast_RTX(t *testing.T) { Version: 2, SequenceNumber: sequenceNumber, PayloadType: 96, - SSRC: uint32(i), + SSRC: uint32(i + 1), }, Payload: []byte{0x00}, } @@ -1591,7 +1591,7 @@ func TestPeerConnection_Simulcast_RTX(t *testing.T) { Version: 2, SequenceNumber: sequenceNumber, PayloadType: 96, - SSRC: uint32(i), + SSRC: uint32(i + 1), }, Payload: []byte{0x00}, } diff --git a/peerconnection_renegotiation_test.go b/peerconnection_renegotiation_test.go index 33bc714f774..6f4465475f2 100644 --- a/peerconnection_renegotiation_test.go +++ b/peerconnection_renegotiation_test.go @@ -1046,7 +1046,7 @@ func TestPeerConnection_Renegotiation_Simulcast(t *testing.T) { for ssrc, rid := range rids { header := &rtp.Header{ Version: 2, - SSRC: uint32(ssrc), + SSRC: uint32(ssrc + 1), SequenceNumber: sequenceNumber, PayloadType: 96, } diff --git a/rtpreceiver.go b/rtpreceiver.go index 3dd4d83b1aa..abd2dff969b 100644 --- a/rtpreceiver.go +++ b/rtpreceiver.go @@ -201,7 +201,7 @@ func (r *RTPReceiver) startReceive(parameters RTPReceiveParameters) error { var t *trackStreams for idx, ts := range r.tracks { - if ts.track != nil && parameters.Encodings[i].SSRC != 0 && ts.track.SSRC() == parameters.Encodings[i].SSRC { + if ts.track != nil && ts.track.SSRC() == parameters.Encodings[i].SSRC { t = &r.tracks[idx] break } @@ -210,12 +210,10 @@ func (r *RTPReceiver) startReceive(parameters RTPReceiveParameters) error { return fmt.Errorf("%w: %d", errRTPReceiverWithSSRCTrackStreamNotFound, parameters.Encodings[i].SSRC) } - if parameters.Encodings[i].SSRC != 0 { - t.streamInfo = createStreamInfo("", parameters.Encodings[i].SSRC, 0, codec, globalParams.HeaderExtensions) - var err error - if t.rtpReadStream, t.rtpInterceptor, t.rtcpReadStream, t.rtcpInterceptor, err = r.transport.streamsForSSRC(parameters.Encodings[i].SSRC, *t.streamInfo); err != nil { - return err - } + t.streamInfo = createStreamInfo("", parameters.Encodings[i].SSRC, 0, codec, globalParams.HeaderExtensions) + var err error + if t.rtpReadStream, t.rtpInterceptor, t.rtcpReadStream, t.rtcpInterceptor, err = r.transport.streamsForSSRC(parameters.Encodings[i].SSRC, *t.streamInfo); err != nil { + return err } if rtxSsrc := parameters.Encodings[i].RTX.SSRC; rtxSsrc != 0 {