diff --git a/webrtc/src/api/mod.rs b/webrtc/src/api/mod.rs index 2252dfbb2..5314aa830 100644 --- a/webrtc/src/api/mod.rs +++ b/webrtc/src/api/mod.rs @@ -166,6 +166,7 @@ impl API { Arc::clone(&self.media_engine), interceptor, false, + self.setting_engine.enable_sender_rtx, ) .await } diff --git a/webrtc/src/api/setting_engine/mod.rs b/webrtc/src/api/setting_engine/mod.rs index 3387d9d6e..17170a706 100644 --- a/webrtc/src/api/setting_engine/mod.rs +++ b/webrtc/src/api/setting_engine/mod.rs @@ -78,6 +78,7 @@ pub struct SettingEngine { pub(crate) srtp_protection_profiles: Vec, pub(crate) receive_mtu: usize, pub(crate) mid_generator: Option String + Send + Sync>>, + pub(crate) enable_sender_rtx: bool, } impl SettingEngine { @@ -334,4 +335,11 @@ impl SettingEngine { pub fn set_mid_generator(&mut self, f: impl Fn(isize) -> String + Send + Sync + 'static) { self.mid_generator = Some(Arc::new(f)); } + + /// enable_sender_rtx allows outgoing rtx streams to be created where applicable. + /// RTPSender will create an RTP retransmission stream for each source stream where a retransmission + /// codec is configured. + pub fn enable_sender_rtx(&mut self, is_enabled: bool) { + self.enable_sender_rtx = is_enabled; + } } diff --git a/webrtc/src/peer_connection/mod.rs b/webrtc/src/peer_connection/mod.rs index 9f99e18b2..75db0a860 100644 --- a/webrtc/src/peer_connection/mod.rs +++ b/webrtc/src/peer_connection/mod.rs @@ -1404,6 +1404,7 @@ impl RTCPeerConnection { }; let receive_mtu = self.internal.setting_engine.get_receive_mtu(); + let enable_sender_rtx = self.internal.setting_engine.enable_sender_rtx; let receiver = Arc::new(RTCRtpReceiver::new( receive_mtu, @@ -1422,6 +1423,7 @@ impl RTCPeerConnection { Arc::clone(&self.internal.media_engine), Arc::clone(&self.interceptor), false, + enable_sender_rtx, ) .await, ); diff --git a/webrtc/src/peer_connection/peer_connection_internal.rs b/webrtc/src/peer_connection/peer_connection_internal.rs index 85c73e2f3..5c46bb05c 100644 --- a/webrtc/src/peer_connection/peer_connection_internal.rs +++ b/webrtc/src/peer_connection/peer_connection_internal.rs @@ -535,6 +535,7 @@ impl PeerConnectionInternal { Arc::clone(&self.media_engine), interceptor, false, + self.setting_engine.enable_sender_rtx, ) .await, ); @@ -589,6 +590,7 @@ impl PeerConnectionInternal { Arc::clone(&self.media_engine), Arc::clone(&interceptor), false, + self.setting_engine.enable_sender_rtx, ) .await, ); @@ -1387,7 +1389,7 @@ impl PeerConnectionInternal { let sender = transceiver.sender().await; let track_encodings = sender.track_encodings.lock().await; for encoding in track_encodings.iter() { - let track_id = encoding.track.id().to_string(); + let track_id = encoding.track.id(); let kind = match encoding.track.kind() { RTPCodecType::Unspecified => continue, RTPCodecType::Audio => "audio", @@ -1395,12 +1397,22 @@ impl PeerConnectionInternal { }; track_infos.push(TrackInfo { - track_id, + track_id: track_id.to_owned(), ssrc: encoding.ssrc, mid: mid.to_owned(), rid: encoding.track.rid().map(Into::into), kind, }); + + if let Some(rtx) = &encoding.rtx { + track_infos.push(TrackInfo { + track_id: track_id.to_owned(), + ssrc: rtx.ssrc, + mid: mid.to_owned(), + rid: encoding.track.rid().map(Into::into), + kind, + }); + } } } diff --git a/webrtc/src/peer_connection/sdp/mod.rs b/webrtc/src/peer_connection/sdp/mod.rs index 34983cc59..b3c6012a1 100644 --- a/webrtc/src/peer_connection/sdp/mod.rs +++ b/webrtc/src/peer_connection/sdp/mod.rs @@ -595,6 +595,23 @@ pub(crate) async fn add_transceiver_sdp( track.stream_id().to_owned(), /* streamLabel */ track.id().to_owned(), ); + + if encoding.rtx.ssrc != 0 { + media = media.with_media_source( + encoding.rtx.ssrc, + track.stream_id().to_owned(), + track.stream_id().to_owned(), + track.id().to_owned(), + ); + + media = media.with_value_attribute( + ATTR_KEY_SSRCGROUP.to_owned(), + format!( + "{} {} {}", + SEMANTIC_TOKEN_FLOW_IDENTIFICATION, encoding.ssrc, encoding.rtx.ssrc + ), + ); + } } if send_parameters.encodings.len() > 1 { diff --git a/webrtc/src/peer_connection/sdp/sdp_test.rs b/webrtc/src/peer_connection/sdp/sdp_test.rs index 9fc3cf9f8..4b86d30e5 100644 --- a/webrtc/src/peer_connection/sdp/sdp_test.rs +++ b/webrtc/src/peer_connection/sdp/sdp_test.rs @@ -701,6 +701,7 @@ async fn test_media_description_fingerprints() -> Result<()> { Arc::clone(&api.media_engine), Arc::clone(&interceptor), false, + false, ) .await, )) @@ -1148,6 +1149,161 @@ async fn test_populate_sdp() -> Result<()> { assert_eq!(offer_sdp.attribute(ATTR_KEY_GROUP), None); } + // "Sender RTX" + { + let mut se = SettingEngine::default(); + se.enable_sender_rtx(true); + + let mut me = MediaEngine::default(); + me.register_default_codecs()?; + + me.register_codec( + RTCRtpCodecParameters { + capability: RTCRtpCodecCapability { + mime_type: "video/rtx".to_owned(), + clock_rate: 90000, + channels: 0, + sdp_fmtp_line: "apt=96".to_string(), + rtcp_feedback: vec![], + }, + payload_type: 97, + ..Default::default() + }, + RTPCodecType::Video, + )?; + + me.push_codecs(me.video_codecs.clone(), RTPCodecType::Video) + .await; + + let api = APIBuilder::new() + .with_media_engine(me) + .with_setting_engine(se.clone()) + .build(); + let interceptor = api.interceptor_registry.build("")?; + let transport = Arc::new(RTCDtlsTransport::default()); + let receiver = Arc::new(api.new_rtp_receiver( + RTPCodecType::Video, + Arc::clone(&transport), + Arc::clone(&interceptor), + )); + + let codec = RTCRtpCodecCapability { + mime_type: "video/vp8".to_owned(), + ..Default::default() + }; + + let track = Arc::new(TrackLocalStaticSample::new_with_rid( + codec.clone(), + "video".to_owned(), + "f".to_owned(), + "webrtc-rs".to_owned(), + )); + + let sender = Arc::new( + api.new_rtp_sender( + Some(track), + Arc::clone(&transport), + Arc::clone(&interceptor), + ) + .await, + ); + + sender + .add_encoding(Arc::new(TrackLocalStaticSample::new_with_rid( + codec.clone(), + "video".to_owned(), + "h".to_owned(), + "webrtc-rs".to_owned(), + ))) + .await?; + + let tr = RTCRtpTransceiver::new( + receiver, + sender, + RTCRtpTransceiverDirection::Sendonly, + RTPCodecType::Video, + api.media_engine.video_codecs.clone(), + Arc::clone(&api.media_engine), + None, + ) + .await; + + let media_sections = vec![MediaSection { + id: "video".to_owned(), + transceivers: vec![tr], + data: false, + ..Default::default() + }]; + + let d = SessionDescription::default(); + + let params = PopulateSdpParams { + media_description_fingerprint: se.sdp_media_level_fingerprints, + is_icelite: se.candidates.ice_lite, + extmap_allow_mixed: true, + connection_role: DEFAULT_DTLS_ROLE_OFFER.to_connection_role(), + ice_gathering_state: RTCIceGatheringState::Complete, + match_bundle_group: None, + }; + let offer_sdp = populate_sdp( + d, + &[], + &api.media_engine, + &[], + &RTCIceParameters::default(), + &media_sections, + params, + ) + .await?; + + // Test codecs and FID groups + let mut found_vp8 = false; + let mut found_rtx = false; + let mut found_ssrcs: Vec<&str> = Vec::new(); + let mut found_fids = Vec::new(); + for desc in &offer_sdp.media_descriptions { + if desc.media_name.media != "video" { + continue; + } + for a in &desc.attributes { + if a.key.contains("rtpmap") { + if let Some(value) = &a.value { + if value == "96 VP8/90000" { + found_vp8 = true; + } else if value == "97 rtx/90000" { + found_rtx = true; + } + } + } else if a.key == "ssrc" { + if let Some((ssrc, _)) = a.value.as_ref().and_then(|v| v.split_once(' ')) { + found_ssrcs.push(ssrc); + } + } else if a.key == "ssrc-group" { + if let Some(group) = a.value.as_ref().and_then(|v| v.strip_prefix("FID ")) { + let Some((a, b)) = group.split_once(" ") else { + panic!("invalid FID format in sdp") + }; + + found_fids.extend([a, b]); + } + } + } + } + + found_fids.sort(); + + found_ssrcs.sort(); + // the sdp may have multiple attributes for each ssrc + found_ssrcs.dedup(); + + assert!(found_vp8, "vp8 should be present in sdp"); + assert!(found_rtx, "rtx should be present in sdp"); + assert_eq!(found_ssrcs.len(), 4, "all ssrcs should be present in sdp"); + assert_eq!(found_fids.len(), 4, "all fids should be present in sdp"); + + assert_eq!(found_ssrcs, found_fids); + } + Ok(()) } diff --git a/webrtc/src/rtp_transceiver/rtp_receiver/rtp_receiver_test.rs b/webrtc/src/rtp_transceiver/rtp_receiver/rtp_receiver_test.rs index 7520667db..451b6e2cf 100644 --- a/webrtc/src/rtp_transceiver/rtp_receiver/rtp_receiver_test.rs +++ b/webrtc/src/rtp_transceiver/rtp_receiver/rtp_receiver_test.rs @@ -12,7 +12,7 @@ use crate::peer_connection::peer_connection_test::{ close_pair_now, create_vnet_pair, signal_pair, until_connection_state, }; use crate::rtp_transceiver::rtp_codec::RTCRtpHeaderExtensionParameters; -use crate::rtp_transceiver::RTCPFeedback; +use crate::rtp_transceiver::{RTCPFeedback, RTCRtpCodecCapability}; use crate::track::track_local::track_local_static_sample::TrackLocalStaticSample; use crate::track::track_local::TrackLocal; diff --git a/webrtc/src/rtp_transceiver/rtp_sender/mod.rs b/webrtc/src/rtp_transceiver/rtp_sender/mod.rs index eb4c32c2a..dff4ed8db 100644 --- a/webrtc/src/rtp_transceiver/rtp_sender/mod.rs +++ b/webrtc/src/rtp_transceiver/rtp_sender/mod.rs @@ -5,17 +5,19 @@ use std::sync::atomic::Ordering; use std::sync::{Arc, Weak}; use ice::rand::generate_crypto_random_string; -use interceptor::stream_info::StreamInfo; +use interceptor::stream_info::{AssociatedStreamInfo, StreamInfo}; use interceptor::{Attributes, Interceptor, RTCPReader, RTPWriter}; use portable_atomic::AtomicBool; +use tokio::select; use tokio::sync::{watch, Mutex, Notify}; use util::sync::Mutex as SyncMutex; use super::srtp_writer_future::SequenceTransformer; +use super::RTCRtpRtxParameters; use crate::api::media_engine::MediaEngine; use crate::dtls_transport::RTCDtlsTransport; use crate::error::{Error, Result}; -use crate::rtp_transceiver::rtp_codec::RTPCodecType; +use crate::rtp_transceiver::rtp_codec::{codec_rtx_search, RTPCodecType}; use crate::rtp_transceiver::rtp_transceiver_direction::RTCRtpTransceiverDirection; use crate::rtp_transceiver::srtp_writer_future::SrtpWriterFuture; use crate::rtp_transceiver::{ @@ -39,6 +41,16 @@ pub(crate) struct TrackEncoding { pub(crate) context: Mutex, pub(crate) ssrc: SSRC, + + pub(crate) rtx: Option, +} + +pub(crate) struct RtxEncoding { + pub(crate) srtp_stream: Arc, + pub(crate) rtcp_interceptor: Arc, + pub(crate) stream_info: Mutex, + + pub(crate) ssrc: SSRC, } /// RTPSender allows an application to control how a given Track is encoded and transmitted to a remote peer @@ -54,12 +66,14 @@ pub struct RTCRtpSender { pub(crate) track_encodings: Mutex>, seq_trans: Arc, + rtx_seq_trans: Arc, pub(crate) transport: Arc, pub(crate) kind: RTPCodecType, pub(crate) payload_type: PayloadType, receive_mtu: usize, + enable_rtx: bool, /// a transceiver sender since we can just check the /// transceiver negotiation status @@ -104,6 +118,7 @@ impl RTCRtpSender { media_engine: Arc, interceptor: Arc, start_paused: bool, + enable_rtx: bool, ) -> Self { let id = generate_crypto_random_string( 32, @@ -120,6 +135,7 @@ impl RTCRtpSender { }); let seq_trans = Arc::new(SequenceTransformer::new()); + let rtx_seq_trans = Arc::new(SequenceTransformer::new()); let stream_ids = track .as_ref() @@ -129,12 +145,14 @@ impl RTCRtpSender { track_encodings: Mutex::new(vec![]), seq_trans, + rtx_seq_trans, transport, kind, payload_type: 0, receive_mtu, + enable_rtx, negotiated: AtomicBool::new(false), @@ -222,6 +240,38 @@ impl RTCRtpSender { let srtp_rtcp_reader = Arc::clone(&srtp_stream) as Arc; let rtcp_interceptor = self.interceptor.bind_rtcp_reader(srtp_rtcp_reader).await; + let create_rtx_stream = self.enable_rtx && self + .media_engine + .get_codecs_by_kind(track.kind()) + .iter() + .any(|codec| matches!(codec.capability.mime_type.split_once("/"), Some((_, "rtx")))); + + let rtx = if create_rtx_stream { + let ssrc = rand::random::(); + + let srtp_stream = Arc::new(SrtpWriterFuture { + closed: AtomicBool::new(false), + ssrc, + rtp_sender: Arc::downgrade(&self.internal), + rtp_transport: Arc::clone(&self.transport), + rtcp_read_stream: Mutex::new(None), + rtp_write_session: Mutex::new(None), + seq_trans: Arc::clone(&self.rtx_seq_trans), + }); + + let srtp_rtcp_reader = Arc::clone(&srtp_stream) as Arc; + let rtcp_interceptor = self.interceptor.bind_rtcp_reader(srtp_rtcp_reader).await; + + Some(RtxEncoding { + srtp_stream, + rtcp_interceptor, + stream_info: Mutex::new(StreamInfo::default()), + ssrc, + }) + } else { + None + }; + let encoding = TrackEncoding { track, srtp_stream, @@ -229,6 +279,7 @@ impl RTCRtpSender { stream_info: Mutex::new(StreamInfo::default()), context: Mutex::new(TrackLocalContext::default()), ssrc, + rtx, }; track_encodings.push(encoding); @@ -273,7 +324,9 @@ impl RTCRtpSender { rid: e.track.rid().unwrap_or_default().into(), ssrc: e.ssrc, payload_type: self.payload_type, - ..Default::default() + rtx: RTCRtpRtxParameters { + ssrc: e.rtx.as_ref().map(|e| e.ssrc).unwrap_or_default(), + }, }); } @@ -340,6 +393,7 @@ impl RTCRtpSender { } self.seq_trans.reset_offset(); + self.rtx_seq_trans.reset_offset(); let mid = self .rtp_transceiver @@ -429,7 +483,7 @@ impl RTCRtpSender { ¶meters.rtp_parameters.header_extensions, None, ); - context.params.codecs = vec![codec]; + context.params.codecs = vec![codec.clone()]; let srtp_writer = Arc::clone(&encoding.srtp_stream) as Arc; let rtp_writer = self @@ -440,12 +494,66 @@ impl RTCRtpSender { *encoding.context.lock().await = context; *encoding.stream_info.lock().await = stream_info; *write_stream.interceptor_rtp_writer.lock().await = Some(rtp_writer); + + if let (Some(rtx), Some(rtx_codec)) = ( + &encoding.rtx, + codec_rtx_search(&codec, ¶meters.rtp_parameters.codecs), + ) { + let rtx_info = AssociatedStreamInfo { + ssrc: parameters.encodings[idx].ssrc, + payload_type: codec.payload_type, + }; + + let rtx_stream_info = create_stream_info( + self.id.clone(), + parameters.encodings[idx].rtx.ssrc, + rtx_codec.payload_type, + rtx_codec.capability.clone(), + ¶meters.rtp_parameters.header_extensions, + Some(rtx_info), + ); + + let rtx_srtp_writer = + Arc::clone(&rtx.srtp_stream) as Arc; + // ignore the rtp writer, only interceptors can write to the stream + self.interceptor + .bind_local_stream(&rtx_stream_info, rtx_srtp_writer) + .await; + + *rtx.stream_info.lock().await = rtx_stream_info; + + self.receive_rtcp_for_rtx(rtx.rtcp_interceptor.clone()); + } } self.send_called.send_replace(true); Ok(()) } + /// starts a routine that reads the rtx rtcp stream + /// These packets aren't exposed to the user, but we need to process them + /// for TWCC + fn receive_rtcp_for_rtx(&self, rtcp_reader: Arc) { + let receive_mtu = self.receive_mtu; + let stop_called_signal = self.internal.stop_called_signal.clone(); + let stop_called_rx = self.internal.stop_called_rx.clone(); + + tokio::spawn(async move { + let attrs = Attributes::new(); + let mut b = vec![0u8; receive_mtu]; + while !stop_called_signal.load(Ordering::SeqCst) { + select! { + r = rtcp_reader.read(&mut b, &attrs) => { + if r.is_err() { + break + } + }, + _ = stop_called_rx.notified() => break, + } + } + }); + } + /// stop irreversibly stops the RTPSender pub async fn stop(&self) -> Result<()> { if self.stop_called_signal.load(Ordering::SeqCst) { @@ -466,6 +574,13 @@ impl RTCRtpSender { self.interceptor.unbind_local_stream(&stream_info).await; encoding.srtp_stream.close().await?; + + if let Some(rtx) = &encoding.rtx { + let rtx_stream_info = rtx.stream_info.lock().await; + self.interceptor.unbind_local_stream(&rtx_stream_info).await; + + rtx.srtp_stream.close().await?; + } } Ok(()) @@ -544,7 +659,8 @@ impl RTCRtpSender { /// Errors if this [`RTCRtpSender`] has started to send data or sequence /// transforming has been already enabled. pub fn enable_seq_transformer(&self) -> Result<()> { - self.seq_trans.enable() + self.seq_trans.enable()?; + self.rtx_seq_trans.enable() } /// Will asynchronously block/wait until send() has been called diff --git a/webrtc/src/rtp_transceiver/rtp_sender/rtp_sender_test.rs b/webrtc/src/rtp_transceiver/rtp_sender/rtp_sender_test.rs index cdae29c4f..a947e888c 100644 --- a/webrtc/src/rtp_transceiver/rtp_sender/rtp_sender_test.rs +++ b/webrtc/src/rtp_transceiver/rtp_sender/rtp_sender_test.rs @@ -1,4 +1,7 @@ +use async_trait::async_trait; use bytes::Bytes; +use interceptor::registry::Registry; +use interceptor::InterceptorBuilder; use portable_atomic::AtomicU64; use std::sync::atomic::Ordering; use std::sync::Arc; @@ -623,3 +626,173 @@ async fn test_rtp_sender_add_encoding() -> Result<()> { close_pair_now(&sender, &receiver).await; Ok(()) } + +#[derive(Debug)] +enum TestInterceptorEvent { + BindLocal(StreamInfo), + BindRemote(StreamInfo), + UnbindLocal(StreamInfo), + UnbindRemote(StreamInfo), +} + +#[derive(Clone)] +struct TestInterceptor(mpsc::UnboundedSender); + +#[async_trait] +impl Interceptor for TestInterceptor { + async fn bind_rtcp_reader( + &self, + reader: Arc, + ) -> Arc { + reader + } + + async fn bind_rtcp_writer( + &self, + writer: Arc, + ) -> Arc { + writer + } + + async fn bind_local_stream( + &self, + info: &StreamInfo, + writer: Arc, + ) -> Arc { + let _ = self.0.send(TestInterceptorEvent::BindLocal(info.clone())); + writer + } + + async fn unbind_local_stream(&self, info: &StreamInfo) { + let _ = self.0.send(TestInterceptorEvent::UnbindLocal(info.clone())); + } + + async fn bind_remote_stream( + &self, + info: &StreamInfo, + reader: Arc, + ) -> Arc { + let _ = self.0.send(TestInterceptorEvent::BindRemote(info.clone())); + reader + } + + async fn unbind_remote_stream(&self, info: &StreamInfo) { + let _ = self + .0 + .send(TestInterceptorEvent::UnbindRemote(info.clone())); + } + + async fn close(&self) -> std::result::Result<(), interceptor::Error> { + Ok(()) + } +} + +impl InterceptorBuilder for TestInterceptor { + fn build( + &self, + _id: &str, + ) -> std::result::Result, interceptor::Error> { + Ok(Arc::new(self.clone())) + } +} + +#[tokio::test] +async fn test_rtp_sender_rtx() -> Result<()> { + let mut s = SettingEngine::default(); + s.enable_sender_rtx(true); + + let (interceptor_tx, mut interceptor_rx) = mpsc::unbounded_channel(); + + let mut registry = Registry::new(); + registry.add(Box::new(TestInterceptor(interceptor_tx))); + + let mut m = MediaEngine::default(); + m.register_default_codecs()?; + // only register rtx for VP8 + m.register_codec( + RTCRtpCodecParameters { + capability: RTCRtpCodecCapability { + mime_type: "video/rtx".to_owned(), + clock_rate: 90000, + channels: 0, + sdp_fmtp_line: "apt=96".to_string(), + rtcp_feedback: vec![], + }, + payload_type: 97, + ..Default::default() + }, + RTPCodecType::Video, + )?; + + let api = APIBuilder::new() + .with_setting_engine(s) + .with_media_engine(m) + .with_interceptor_registry(registry) + .build(); + + let (mut offerer, mut answerer) = new_pair(&api).await?; + + let track_a = Arc::new(TrackLocalStaticSample::new( + RTCRtpCodecCapability { + mime_type: MIME_TYPE_VP8.to_owned(), + ..Default::default() + }, + "video".to_owned(), + "webrtc-rs".to_owned(), + )); + + let track_b = Arc::new(TrackLocalStaticSample::new( + RTCRtpCodecCapability { + mime_type: MIME_TYPE_H264.to_owned(), + ..Default::default() + }, + "video".to_owned(), + "webrtc-rs".to_owned(), + )); + + let rtp_sender_a = offerer + .add_track(Arc::clone(&track_a) as Arc) + .await?; + + let rtp_sender_b = offerer + .add_track(Arc::clone(&track_b) as Arc) + .await?; + + signal_pair(&mut offerer, &mut answerer).await?; + + // rtx enabled for vp8 + assert!(rtp_sender_a.track().await.is_some()); + assert!(rtp_sender_a.track_encodings.lock().await[0].rtx.is_some()); + + // no rtx for h264 + assert!(rtp_sender_b.track().await.is_some()); + assert!(rtp_sender_b.track_encodings.lock().await[0].rtx.is_some()); + + close_pair_now(&offerer, &answerer).await; + + let mut vp8_ssrcs = Vec::new(); + let mut h264_ssrcs = Vec::new(); + let mut rtx_associated_ssrcs = Vec::new(); + + // pair closed, all interceptor events should be buffered + while let Ok(event) = interceptor_rx.try_recv() { + if let TestInterceptorEvent::BindLocal(info) = event { + match info.mime_type.as_str() { + MIME_TYPE_VP8 => vp8_ssrcs.push(info.ssrc), + MIME_TYPE_H264 => h264_ssrcs.push(info.ssrc), + "video/rtx" => rtx_associated_ssrcs.push( + info.associated_stream + .expect("rtx without asscoiated stream") + .ssrc, + ), + mime => panic!("unexpected mime: {mime}"), + } + } + } + + assert_eq!(vp8_ssrcs.len(), 1); + assert_eq!(h264_ssrcs.len(), 1); + assert_eq!(rtx_associated_ssrcs, vp8_ssrcs); + + Ok(()) +}