diff --git a/dtlstransport.go b/dtlstransport.go index ec08a0846d3..bfc38d2b724 100644 --- a/dtlstransport.go +++ b/dtlstransport.go @@ -51,7 +51,7 @@ type DTLSTransport struct { srtpSession, srtcpSession atomic.Value srtpEndpoint, srtcpEndpoint *mux.Endpoint - simulcastStreams []*srtp.ReadStreamSRTP + simulcastStreams []simulcastStreamPair srtpReady chan struct{} dtlsMatcher mux.MatchFunc @@ -60,6 +60,11 @@ type DTLSTransport struct { log logging.LeveledLogger } +type simulcastStreamPair struct { + srtp *srtp.ReadStreamSRTP + srtcp *srtp.ReadStreamSRTCP +} + // NewDTLSTransport creates a new DTLSTransport. // This constructor is part of the ORTC API. It is not // meant to be used together with the basic WebRTC API. @@ -436,7 +441,8 @@ func (t *DTLSTransport) Stop() error { } for i := range t.simulcastStreams { - closeErrs = append(closeErrs, t.simulcastStreams[i].Close()) + closeErrs = append(closeErrs, t.simulcastStreams[i].srtp.Close()) + closeErrs = append(closeErrs, t.simulcastStreams[i].srtcp.Close()) } if t.conn != nil { @@ -477,11 +483,11 @@ func (t *DTLSTransport) ensureICEConn() error { return nil } -func (t *DTLSTransport) storeSimulcastStream(s *srtp.ReadStreamSRTP) { +func (t *DTLSTransport) storeSimulcastStream(srtpReadStream *srtp.ReadStreamSRTP, srtcpReadStream *srtp.ReadStreamSRTCP) { t.lock.Lock() defer t.lock.Unlock() - t.simulcastStreams = append(t.simulcastStreams, s) + t.simulcastStreams = append(t.simulcastStreams, simulcastStreamPair{srtpReadStream, srtcpReadStream}) } func (t *DTLSTransport) streamsForSSRC(ssrc SSRC, streamInfo interceptor.StreamInfo) (*srtp.ReadStreamSRTP, interceptor.RTPReader, *srtp.ReadStreamSRTCP, interceptor.RTCPReader, error) { diff --git a/peerconnection.go b/peerconnection.go index 5876e5f76cd..2b8a9038539 100644 --- a/peerconnection.go +++ b/peerconnection.go @@ -1670,26 +1670,42 @@ func (pc *PeerConnection) undeclaredRTPMediaProcessor() { return } - stream, ssrc, err := srtpSession.AcceptStream() + srtcpSession, err := pc.dtlsTransport.getSRTCPSession() + if err != nil { + pc.log.Warnf("undeclaredMediaProcessor failed to open SrtcpSession: %v", err) + return + } + + srtpReadStream, ssrc, err := srtpSession.AcceptStream() if err != nil { pc.log.Warnf("Failed to accept RTP %v", err) return } + // open accompanying srtcp stream + srtcpReadStream, err := srtcpSession.OpenReadStream(ssrc) + if err != nil { + pc.log.Warnf("Failed to open RTCP stream for %d: %v", ssrc, err) + return + } + if pc.isClosed.get() { - if err = stream.Close(); err != nil { + if err = srtpReadStream.Close(); err != nil { pc.log.Warnf("Failed to close RTP stream %v", err) } + if err = srtcpReadStream.Close(); err != nil { + pc.log.Warnf("Failed to close RTCP stream %v", err) + } continue } + pc.dtlsTransport.storeSimulcastStream(srtpReadStream, srtcpReadStream) + if ssrc == 0 { go pc.handleNonMediaBandwidthProbe() continue } - pc.dtlsTransport.storeSimulcastStream(stream) - if atomic.AddUint64(&simulcastRoutineCount, 1) >= simulcastMaxProbeRoutines { atomic.AddUint64(&simulcastRoutineCount, ^uint64(0)) pc.log.Warn(ErrSimulcastProbeOverflow.Error()) @@ -1701,7 +1717,7 @@ func (pc *PeerConnection) undeclaredRTPMediaProcessor() { pc.log.Errorf(incomingUnhandledRTPSsrc, ssrc, err) } atomic.AddUint64(&simulcastRoutineCount, ^uint64(0)) - }(stream, SSRC(ssrc)) + }(srtpReadStream, SSRC(ssrc)) } }