diff --git a/track_local_static.go b/track_local_static.go index d4ca1ed5c06..2dd665cd240 100644 --- a/track_local_static.go +++ b/track_local_static.go @@ -31,6 +31,7 @@ type TrackLocalStaticRTP struct { mu sync.RWMutex bindings []trackBinding codec RTPCodecCapability + payloader func(RTPCodecCapability) (rtp.Payloader, error) id, rid, streamID string } @@ -57,6 +58,13 @@ func WithRTPStreamID(rid string) func(*TrackLocalStaticRTP) { } } +// WithPayloader allows the user to override the Payloader +func WithPayloader(h func(RTPCodecCapability) (rtp.Payloader, error)) func(*TrackLocalStaticRTP) { + return func(s *TrackLocalStaticRTP) { + s.payloader = h + } +} + // Bind is called by the PeerConnection after negotiation is complete // This asserts that the code requested is supported by the remote peer. // If so it sets up all the state (SSRC and PayloadType) to have a call @@ -250,7 +258,12 @@ func (s *TrackLocalStaticSample) Bind(t TrackLocalContext) (RTPCodecParameters, return codec, nil } - payloader, err := payloaderForCodec(codec.RTPCodecCapability) + payloadHandler := s.rtpTrack.payloader + if payloadHandler == nil { + payloadHandler = payloaderForCodec + } + + payloader, err := payloadHandler(codec.RTPCodecCapability) if err != nil { return codec, err } diff --git a/track_local_static_test.go b/track_local_static_test.go index b3c3c076612..db4e1855544 100644 --- a/track_local_static_test.go +++ b/track_local_static_test.go @@ -9,12 +9,14 @@ package webrtc import ( "context" "errors" + "sync/atomic" "testing" "time" "github.com/pion/rtp" "github.com/pion/transport/v3/test" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // If a remote doesn't support a Codec used by a `TrackLocalStatic` @@ -336,3 +338,49 @@ func Test_TrackLocalStatic_RTX(t *testing.T) { closePairNow(t, offerer, answerer) } + +type customCodecPayloader struct { + invokeCount atomic.Int32 +} + +func (c *customCodecPayloader) Payload(_ uint16, payload []byte) [][]byte { + c.invokeCount.Add(1) + return [][]byte{payload} +} + +func Test_TrackLocalStatic_Payloader(t *testing.T) { + const mimeTypeCustomCodec = "video/custom-codec" + + mediaEngine := &MediaEngine{} + assert.NoError(t, mediaEngine.RegisterCodec(RTPCodecParameters{ + RTPCodecCapability: RTPCodecCapability{MimeType: mimeTypeCustomCodec, ClockRate: 90000, Channels: 0, SDPFmtpLine: "", RTCPFeedback: nil}, + PayloadType: 96, + }, RTPCodecTypeVideo)) + + offerer, err := NewAPI(WithMediaEngine(mediaEngine)).NewPeerConnection(Configuration{}) + assert.NoError(t, err) + + answerer, err := NewAPI(WithMediaEngine(mediaEngine)).NewPeerConnection(Configuration{}) + assert.NoError(t, err) + + customPayloader := &customCodecPayloader{} + track, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: mimeTypeCustomCodec}, "video", "pion", WithPayloader(func(c RTPCodecCapability) (rtp.Payloader, error) { + require.Equal(t, c.MimeType, mimeTypeCustomCodec) + return customPayloader, nil + })) + assert.NoError(t, err) + + _, err = offerer.AddTrack(track) + assert.NoError(t, err) + + assert.NoError(t, signalPair(offerer, answerer)) + + onTrackFired, onTrackFiredFunc := context.WithCancel(context.Background()) + answerer.OnTrack(func(*TrackRemote, *RTPReceiver) { + onTrackFiredFunc() + }) + + sendVideoUntilDone(onTrackFired.Done(), t, []*TrackLocalStaticSample{track}) + + closePairNow(t, offerer, answerer) +}