Skip to content

Commit

Permalink
Use SyncPool for NACKs with distinct SSRC
Browse files Browse the repository at this point in the history
NACK storage would use a sync.Pool to reduce memory pressure. The new
distinict SSRC feature did not properly use this pool. This change
updates the RTPBuffer to use it for both paths.

Fixes #281
  • Loading branch information
nithu1991 authored and Sean-Der committed Oct 28, 2024
1 parent d34a3c5 commit 94457ee
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 14 deletions.
36 changes: 22 additions & 14 deletions internal/rtpbuffer/packet_factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (
"github.com/pion/rtp"
)

const rtxSsrcByteLength = 2

// PacketFactory allows custom logic around the handle of RTP Packets before they added to the RTPBuffer.
// The NoOpPacketFactory doesn't copy packets, while the RetainablePacket will take a copy before adding
type PacketFactory interface {
Expand Down Expand Up @@ -68,32 +70,38 @@ func (m *PacketFactoryCopy) NewPacket(header *rtp.Header, payload []byte, rtxSsr
if !ok {
return nil, errFailedToCastPayloadPool
}

size := copy(*p.buffer, payload)
p.payload = (*p.buffer)[:size]
if rtxSsrc != 0 && rtxPayloadType != 0 {
size := copy((*p.buffer)[rtxSsrcByteLength:], payload)
p.payload = (*p.buffer)[:size+rtxSsrcByteLength]
} else {
size := copy(*p.buffer, payload)
p.payload = (*p.buffer)[:size]
}
}

if rtxSsrc != 0 && rtxPayloadType != 0 {
// Store the original sequence number and rewrite the sequence number.
originalSequenceNumber := p.header.SequenceNumber
p.header.SequenceNumber = m.rtxSequencer.NextSequenceNumber()
if payload == nil {
p.buffer, ok = m.payloadPool.Get().(*[]byte)
if !ok {
return nil, errFailedToCastPayloadPool
}
p.payload = (*p.buffer)[:rtxSsrcByteLength]
}
// Write the original sequence number at the beginning of the payload.
binary.BigEndian.PutUint16(p.payload, p.header.SequenceNumber)

// Rewrite the SSRC.
p.header.SSRC = rtxSsrc
// Rewrite the payload type.
p.header.PayloadType = rtxPayloadType

// Rewrite the sequence number.
p.header.SequenceNumber = m.rtxSequencer.NextSequenceNumber()
// Remove padding if present.
paddingLength := 0
if p.header.Padding && p.payload != nil && len(p.payload) > 0 {
paddingLength = int(p.payload[len(p.payload)-1])
paddingLength := int(p.payload[len(p.payload)-1])
p.header.Padding = false
p.payload = (*p.buffer)[:len(p.payload)-paddingLength]
}

// Write the original sequence number at the beginning of the payload.
payload := make([]byte, 2)
binary.BigEndian.PutUint16(payload, originalSequenceNumber)
p.payload = append(payload, p.payload[:len(p.payload)-paddingLength]...)
}

return p, nil
Expand Down
117 changes: 117 additions & 0 deletions internal/rtpbuffer/rtpbuffer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,66 @@ func TestRTPBuffer(t *testing.T) {
}
}

func TestRTPBuffer_WithRTX(t *testing.T) {
pm := NewPacketFactoryCopy()
for _, start := range []uint16{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 511, 512, 513, 32767, 32768, 32769, 65527, 65528, 65529, 65530, 65531, 65532, 65533, 65534, 65535} {
start := start

sb, err := NewRTPBuffer(8)
require.NoError(t, err)

add := func(nums ...uint16) {
for _, n := range nums {
seq := start + n
pkt, err := pm.NewPacket(&rtp.Header{SequenceNumber: seq, PayloadType: 2}, []byte("originalcontent"), 1, 1)
require.NoError(t, err)
sb.Add(pkt)
}
}

assertGet := func(nums ...uint16) {
t.Helper()
for _, n := range nums {
seq := start + n
packet := sb.Get(seq)
if packet == nil {
t.Errorf("packet not found: %d", seq)
continue
}
if packet.Header().SSRC != 1 && packet.Header().PayloadType != 1 {
t.Errorf("packet for %d returned with incorrect SSRC : %d and PayloadType: %d", seq, packet.Header().SSRC, packet.Header().PayloadType)
}
packet.Release()
}
}
assertNOTGet := func(nums ...uint16) {
t.Helper()
for _, n := range nums {
seq := start + n
packet := sb.Get(seq)
if packet != nil {
t.Errorf("packet found for %d: %d", seq, packet.Header().SequenceNumber)
}
}
}

add(0, 1, 2, 3, 4, 5, 6, 7)
assertGet(0, 1, 2, 3, 4, 5, 6, 7)

add(8)
assertGet(8)
assertNOTGet(0)

add(10)
assertGet(10)
assertNOTGet(1, 2, 9)

add(22)
assertGet(22)
assertNOTGet(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
}
}

func TestRTPBuffer_Overridden(t *testing.T) {
// override original packet content and get
pm := NewPacketFactoryCopy()
Expand Down Expand Up @@ -98,3 +158,60 @@ func TestRTPBuffer_Overridden(t *testing.T) {

require.Nil(t, sb.Get(1))
}

func TestRTPBuffer_Overridden_WithRTX_AND_Padding(t *testing.T) {
// override original packet content and get
pm := NewPacketFactoryCopy()
sb, err := NewRTPBuffer(1)
require.NoError(t, err)
require.Equal(t, uint16(1), sb.size)

originalBytes := []byte("originalContent\x01")
pkt, err := pm.NewPacket(&rtp.Header{SequenceNumber: 1, Padding: true, SSRC: 2, PayloadType: 3}, originalBytes, 1, 1)
require.NoError(t, err)
sb.Add(pkt)

// change payload
copy(originalBytes, "altered")
retrieved := sb.Get(1)
require.NotNil(t, retrieved)
require.Equal(t, "\x00\x01originalContent", string(retrieved.Payload()))
retrieved.Release()
require.Equal(t, 1, retrieved.count)

// ensure original packet is released
pkt, err = pm.NewPacket(&rtp.Header{SequenceNumber: 2}, originalBytes, 1, 1)
require.NoError(t, err)
sb.Add(pkt)
require.Equal(t, 0, retrieved.count)

require.Nil(t, sb.Get(1))
}

func TestRTPBuffer_Overridden_WithRTX_NILPayload(t *testing.T) {
// override original packet content and get
pm := NewPacketFactoryCopy()
sb, err := NewRTPBuffer(1)
require.NoError(t, err)
require.Equal(t, uint16(1), sb.size)

pkt, err := pm.NewPacket(&rtp.Header{SequenceNumber: 1}, nil, 1, 1)
require.NoError(t, err)
sb.Add(pkt)

// change payload

retrieved := sb.Get(1)
require.NotNil(t, retrieved)
require.Equal(t, "\x00\x01", string(retrieved.Payload()))
retrieved.Release()
require.Equal(t, 1, retrieved.count)

// ensure original packet is released
pkt, err = pm.NewPacket(&rtp.Header{SequenceNumber: 2}, []byte("altered"), 1, 1)
require.NoError(t, err)
sb.Add(pkt)
require.Equal(t, 0, retrieved.count)

require.Nil(t, sb.Get(1))
}

0 comments on commit 94457ee

Please sign in to comment.