diff --git a/pkg/nack/responder_interceptor.go b/pkg/nack/responder_interceptor.go index 03b084d..1d5745f 100644 --- a/pkg/nack/responder_interceptor.go +++ b/pkg/nack/responder_interceptor.go @@ -4,6 +4,7 @@ package nack import ( + "encoding/binary" "sync" "github.com/pion/interceptor" @@ -62,6 +63,11 @@ type ResponderInterceptor struct { type localStream struct { sendBuffer *sendBuffer rtpWriter interceptor.RTPWriter + + // Non-zero if Retransmissions should be sent on a distinct stream + rtxSsrc uint32 + rtxPayloadType uint8 + rtxSequencer rtp.Sequencer } // NewResponderInterceptor returns a new ResponderInterceptorFactor @@ -108,7 +114,13 @@ func (n *ResponderInterceptor) BindLocalStream(info *interceptor.StreamInfo, wri // error is already checked in NewGeneratorInterceptor sendBuffer, _ := newSendBuffer(n.size) n.streamsMu.Lock() - n.streams[info.SSRC] = &localStream{sendBuffer: sendBuffer, rtpWriter: writer} + n.streams[info.SSRC] = &localStream{ + sendBuffer: sendBuffer, + rtpWriter: writer, + rtxSsrc: info.SSRCRetransmission, + rtxPayloadType: info.PayloadTypeRetransmission, + rtxSequencer: rtp.NewRandomSequencer(), + } n.streamsMu.Unlock() return interceptor.RTPWriterFunc(func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) { @@ -139,8 +151,43 @@ func (n *ResponderInterceptor) resendPackets(nack *rtcp.TransportLayerNack) { for i := range nack.Nacks { nack.Nacks[i].Range(func(seq uint16) bool { if p := stream.sendBuffer.get(seq); p != nil { - if _, err := stream.rtpWriter.Write(p.Header(), p.Payload(), interceptor.Attributes{}); err != nil { - n.log.Warnf("failed resending nacked packet: %+v", err) + if stream.rtxSsrc != 0 { + // Store the original sequence number and rewrite the sequence number. + originalSequenceNumber := p.Header().SequenceNumber + p.Header().SequenceNumber = stream.rtxSequencer.NextSequenceNumber() + + // Rewrite the SSRC. + p.Header().SSRC = stream.rtxSsrc + // Rewrite the payload type. + p.Header().PayloadType = stream.rtxPayloadType + + // Remove padding if present. + paddingLength := 0 + originPayload := p.Payload() + if p.Header().Padding { + paddingLength = int(originPayload[len(originPayload)-1]) + p.Header().Padding = false + } + + // Write the original sequence number at the beginning of the payload. + payload := make([]byte, 2) + binary.BigEndian.PutUint16(payload, originalSequenceNumber) + payload = append(payload, originPayload[:len(originPayload)-paddingLength]...) + + // Send RTX packet. + if _, err := stream.rtpWriter.Write(p.Header(), payload, interceptor.Attributes{}); err != nil { + n.log.Warnf("failed sending rtx packet: %+v", err) + } + + // Resore the Padding and SSRC. + if paddingLength > 0 { + p.Header().Padding = true + } + p.Header().SequenceNumber = originalSequenceNumber + } else { + if _, err := stream.rtpWriter.Write(p.Header(), p.Payload(), interceptor.Attributes{}); err != nil { + n.log.Warnf("failed resending nacked packet: %+v", err) + } } p.Release() } diff --git a/pkg/nack/responder_interceptor_test.go b/pkg/nack/responder_interceptor_test.go index 9eb5b23..68e88d1 100644 --- a/pkg/nack/responder_interceptor_test.go +++ b/pkg/nack/responder_interceptor_test.go @@ -4,6 +4,7 @@ package nack import ( + "encoding/binary" "testing" "time" @@ -231,3 +232,60 @@ func TestResponderInterceptor_StreamFilter(t *testing.T) { case <-time.After(10 * time.Millisecond): } } + +func TestResponderInterceptor_RFC4588(t *testing.T) { + f, err := NewResponderInterceptor() + require.NoError(t, err) + + i, err := f.NewInterceptor("") + require.NoError(t, err) + + stream := test.NewMockStream(&interceptor.StreamInfo{ + SSRC: 1, + SSRCRetransmission: 2, + PayloadTypeRetransmission: 2, + RTCPFeedback: []interceptor.RTCPFeedback{{Type: "nack"}}, + }, i) + defer func() { + require.NoError(t, stream.Close()) + }() + + for _, seqNum := range []uint16{10, 11, 12, 14, 15} { + require.NoError(t, stream.WriteRTP(&rtp.Packet{Header: rtp.Header{SequenceNumber: seqNum}})) + + select { + case p := <-stream.WrittenRTP(): + require.Equal(t, seqNum, p.SequenceNumber) + case <-time.After(10 * time.Millisecond): + t.Fatal("written rtp packet not found") + } + } + + stream.ReceiveRTCP([]rtcp.Packet{ + &rtcp.TransportLayerNack{ + MediaSSRC: 1, + SenderSSRC: 2, + Nacks: []rtcp.NackPair{ + {PacketID: 11, LostPackets: 0b1011}, // sequence numbers: 11, 12, 13, 15 + }, + }, + }) + + // seq number 13 was never sent, so it can't be resent + for _, seqNum := range []uint16{11, 12, 15} { + select { + case p := <-stream.WrittenRTP(): + require.Equal(t, uint32(2), p.SSRC) + require.Equal(t, uint8(2), p.PayloadType) + require.Equal(t, binary.BigEndian.Uint16(p.Payload), seqNum) + case <-time.After(10 * time.Millisecond): + t.Fatal("written rtp packet not found") + } + } + + select { + case p := <-stream.WrittenRTP(): + t.Errorf("no more rtp packets expected, found sequence number: %v", p.SequenceNumber) + case <-time.After(10 * time.Millisecond): + } +}