diff --git a/internal/rtpbuffer/retainable_packet.go b/internal/rtpbuffer/retainable_packet.go index 5a9afd7..1425050 100644 --- a/internal/rtpbuffer/retainable_packet.go +++ b/internal/rtpbuffer/retainable_packet.go @@ -33,6 +33,14 @@ func (p *RetainablePacket) Payload() []byte { return p.payload } +// Packet returns a RTP Packet for a RetainablePacket +func (p *RetainablePacket) Packet() *rtp.Packet { + return &rtp.Packet{ + Header: *p.Header(), + Payload: p.Payload(), + } +} + // Retain increases the reference count of the RetainablePacket func (p *RetainablePacket) Retain() error { p.countMu.Lock() @@ -46,10 +54,15 @@ func (p *RetainablePacket) Retain() error { } // Release decreases the reference count of the RetainablePacket and frees if needed -func (p *RetainablePacket) Release() { +func (p *RetainablePacket) Release(force bool) { p.countMu.Lock() defer p.countMu.Unlock() - p.count-- + + if !force { + p.count-- + } else { + p.count = 0 + } if p.count == 0 { // release back to pool @@ -59,3 +72,10 @@ func (p *RetainablePacket) Release() { p.payload = nil } } + +func (p *RetainablePacket) getCount() int { + p.countMu.Lock() + defer p.countMu.Unlock() + + return p.count +} diff --git a/internal/rtpbuffer/rtpbuffer.go b/internal/rtpbuffer/rtpbuffer.go index 9253507..8d5947e 100644 --- a/internal/rtpbuffer/rtpbuffer.go +++ b/internal/rtpbuffer/rtpbuffer.go @@ -63,7 +63,7 @@ func (r *RTPBuffer) Add(packet *RetainablePacket) { idx := i % r.size prevPacket := r.packets[idx] if prevPacket != nil { - prevPacket.Release() + prevPacket.Release(false) } r.packets[idx] = nil } @@ -72,7 +72,7 @@ func (r *RTPBuffer) Add(packet *RetainablePacket) { idx := seq % r.size prevPacket := r.packets[idx] if prevPacket != nil { - prevPacket.Release() + prevPacket.Release(false) } r.packets[idx] = packet r.lastAdded = seq @@ -101,3 +101,36 @@ func (r *RTPBuffer) Get(seq uint16) *RetainablePacket { } return pkt } + +// GetTimestamp returns a RetainablePacket for the requested timestamp +func (r *RTPBuffer) GetTimestamp(timestamp uint32) *RetainablePacket { + for i := range r.packets { + pkt := r.packets[i] + if pkt != nil && pkt.Header() != nil && pkt.Header().Timestamp == timestamp { + if err := pkt.Retain(); err != nil { + return nil + } + + return pkt + } + } + return nil +} + +// Length returns the count of valid RetainablePackets in the RTPBuffer +func (r *RTPBuffer) Length() (length uint16) { + for i := range r.packets { + if r.packets[i] != nil && r.packets[i].getCount() != 0 { + length++ + } + } + + return +} + +// Clear erases all the packets in the RTPBuffer +func (r *RTPBuffer) Clear() { + r.lastAdded = 0 + r.started = false + r.packets = make([]*RetainablePacket, r.size) +} diff --git a/internal/rtpbuffer/rtpbuffer_test.go b/internal/rtpbuffer/rtpbuffer_test.go index 746fb5c..42828b3 100644 --- a/internal/rtpbuffer/rtpbuffer_test.go +++ b/internal/rtpbuffer/rtpbuffer_test.go @@ -39,7 +39,7 @@ func TestRTPBuffer(t *testing.T) { if packet.Header().SequenceNumber != seq { t.Errorf("packet for %d returned with incorrect SequenceNumber: %d", seq, packet.Header().SequenceNumber) } - packet.Release() + packet.Release(false) } } assertNOTGet := func(nums ...uint16) { @@ -87,7 +87,7 @@ func TestRTPBuffer_Overridden(t *testing.T) { retrieved := sb.Get(1) require.NotNil(t, retrieved) require.Equal(t, "originalContent", string(retrieved.Payload())) - retrieved.Release() + retrieved.Release(false) require.Equal(t, 1, retrieved.count) // ensure original packet is released diff --git a/pkg/jitterbuffer/jitter_buffer.go b/pkg/jitterbuffer/jitter_buffer.go index 9cf9731..211df03 100644 --- a/pkg/jitterbuffer/jitter_buffer.go +++ b/pkg/jitterbuffer/jitter_buffer.go @@ -8,6 +8,7 @@ package jitterbuffer import ( "errors" + "github.com/pion/interceptor/internal/rtpbuffer" "github.com/pion/rtp" ) @@ -20,8 +21,12 @@ type Event string var ( // ErrBufferUnderrun is returned when the buffer has no items ErrBufferUnderrun = errors.New("invalid Peek: Empty jitter buffer") + // ErrPopWhileBuffering is returned if a jitter buffer is not in a playback state ErrPopWhileBuffering = errors.New("attempt to pop while buffering") + + // ErrNotFound is returned when a packet does not exist for a SequenceNumber + ErrNotFound = errors.New("packet with sequence number was not found") ) const ( @@ -63,7 +68,8 @@ type ( // order, and allows removing in either sequence number order or via a // provided timestamp type JitterBuffer struct { - packets *PriorityQueue + packets *rtpbuffer.RTPBuffer + packetFactory rtpbuffer.PacketFactoryNoOp minStartCount uint16 lastSequence uint16 playoutHead uint16 @@ -88,11 +94,16 @@ type Stats struct { // New will initialize a jitter buffer and its associated statistics func New(opts ...Option) *JitterBuffer { + rtpBuffer, err := rtpbuffer.NewRTPBuffer(rtpbuffer.Uint16SizeHalf) + if err != nil || rtpBuffer == nil { + return nil + } + jb := &JitterBuffer{ state: Buffering, stats: Stats{0, 0, 0}, minStartCount: 50, - packets: NewQueue(), + packets: rtpBuffer, listeners: make(map[Event][]EventListener), } @@ -142,18 +153,20 @@ func (jb *JitterBuffer) updateStats(lastPktSeqNo uint16) { // the data so if the memory is expected to be reused, the caller should // take this in to account and pass a copy of the packet they wish to buffer func (jb *JitterBuffer) Push(packet *rtp.Packet) { - if jb.packets.Length() == 0 { + if packetsLen := jb.packets.Length(); packetsLen == 0 { + if !jb.playoutReady { + jb.playoutHead = packet.SequenceNumber + } + jb.emit(StartBuffering) - } - if jb.packets.Length() > 100 { + } else if packetsLen > 100 { jb.stats.overflowCount++ jb.emit(BufferOverflow) } - if !jb.playoutReady && jb.packets.Length() == 0 { - jb.playoutHead = packet.SequenceNumber - } + jb.updateStats(packet.SequenceNumber) - jb.packets.Push(packet, packet.SequenceNumber) + retainablePkt, _ := jb.packetFactory.NewPacket(&packet.Header, packet.Payload, 0, 0) + jb.packets.Add(retainablePkt) jb.updateState() } @@ -184,25 +197,14 @@ func (jb *JitterBuffer) Peek(playoutHead bool) (*rtp.Packet, error) { return nil, ErrBufferUnderrun } if playoutHead && jb.state == Emitting { - return jb.packets.Find(jb.playoutHead) + return jb.PeekAtSequence(jb.playoutHead) } - return jb.packets.Find(jb.lastSequence) + return jb.PeekAtSequence(jb.lastSequence) } // Pop an RTP packet from the jitter buffer at the current playout head func (jb *JitterBuffer) Pop() (*rtp.Packet, error) { - if jb.state != Emitting { - return nil, ErrPopWhileBuffering - } - packet, err := jb.packets.PopAt(jb.playoutHead) - if err != nil { - jb.stats.underflowCount++ - jb.emit(BufferUnderflow) - return nil, err - } - jb.playoutHead = (jb.playoutHead + 1) - jb.updateState() - return packet, nil + return jb.PopAtSequence(jb.playoutHead) } // PopAtSequence will pop an RTP packet from the jitter buffer at the specified Sequence @@ -210,41 +212,45 @@ func (jb *JitterBuffer) PopAtSequence(sq uint16) (*rtp.Packet, error) { if jb.state != Emitting { return nil, ErrPopWhileBuffering } - packet, err := jb.packets.PopAt(sq) - if err != nil { + retainablePacket := jb.packets.Get(sq) + if retainablePacket == nil { jb.stats.underflowCount++ jb.emit(BufferUnderflow) - return nil, err + return nil, ErrNotFound } + + defer retainablePacket.Release(true) jb.playoutHead = (jb.playoutHead + 1) jb.updateState() - return packet, nil + return retainablePacket.Packet(), nil } // PeekAtSequence will return an RTP packet from the jitter buffer at the specified Sequence // without removing it from the buffer func (jb *JitterBuffer) PeekAtSequence(sq uint16) (*rtp.Packet, error) { - packet, err := jb.packets.Find(sq) - if err != nil { - return nil, err + retainablePacket := jb.packets.Get(sq) + if retainablePacket == nil { + return nil, ErrNotFound } - return packet, nil + return retainablePacket.Packet(), nil } // PopAtTimestamp pops an RTP packet from the jitter buffer with the provided timestamp // Call this method repeatedly to drain the buffer at the timestamp -func (jb *JitterBuffer) PopAtTimestamp(ts uint32) (*rtp.Packet, error) { +func (jb *JitterBuffer) PopAtTimestamp(ts uint32) (*rtp.Packet, error) { //nolint: revive if jb.state != Emitting { return nil, ErrPopWhileBuffering } - packet, err := jb.packets.PopAtTimestamp(ts) - if err != nil { + retainablePacket := jb.packets.GetTimestamp(ts) + if retainablePacket == nil { jb.stats.underflowCount++ jb.emit(BufferUnderflow) - return nil, err + return nil, ErrNotFound } + + defer retainablePacket.Release(true) jb.updateState() - return packet, nil + return retainablePacket.Packet(), nil } // Clear will empty the buffer and optionally reset the state diff --git a/pkg/jitterbuffer/jitter_buffer_test.go b/pkg/jitterbuffer/jitter_buffer_test.go index f163a8f..11fca03 100644 --- a/pkg/jitterbuffer/jitter_buffer_test.go +++ b/pkg/jitterbuffer/jitter_buffer_test.go @@ -132,7 +132,7 @@ func TestJitterBuffer(t *testing.T) { assert.Equal(pkt.SequenceNumber, uint16(5002)) assert.Equal(err, nil) for i := 0; i < 100; i++ { - sqnum := uint16((math.MaxUint16 - 32 + i)) + sqnum := uint16((6000 + i)) jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: sqnum, Timestamp: uint32(512 + i)}, Payload: []byte{0x02}}) } pkt, err = jb.Peek(true) diff --git a/pkg/jitterbuffer/priority_queue.go b/pkg/jitterbuffer/priority_queue.go deleted file mode 100644 index 11a8679..0000000 --- a/pkg/jitterbuffer/priority_queue.go +++ /dev/null @@ -1,194 +0,0 @@ -// SPDX-FileCopyrightText: 2023 The Pion community -// SPDX-License-Identifier: MIT - -package jitterbuffer - -import ( - "errors" - - "github.com/pion/rtp" -) - -// PriorityQueue provides a linked list sorting of RTP packets by SequenceNumber -type PriorityQueue struct { - next *node - length uint16 -} - -type node struct { - val *rtp.Packet - next *node - prev *node - priority uint16 -} - -var ( - // ErrInvalidOperation may be returned if a Pop or Find operation is performed on an empty queue - ErrInvalidOperation = errors.New("attempt to find or pop on an empty list") - // ErrNotFound will be returned if the packet cannot be found in the queue - ErrNotFound = errors.New("priority not found") -) - -// NewQueue will create a new PriorityQueue whose order relies on monotonically -// increasing Sequence Number, wrapping at MaxUint16, so -// a packet with sequence number MaxUint16 - 1 will be after 0 -func NewQueue() *PriorityQueue { - return &PriorityQueue{ - next: nil, - length: 0, - } -} - -func newNode(val *rtp.Packet, priority uint16) *node { - return &node{ - val: val, - prev: nil, - next: nil, - priority: priority, - } -} - -// Find a packet in the queue with the provided sequence number, -// regardless of position (the packet is retained in the queue) -func (q *PriorityQueue) Find(sqNum uint16) (*rtp.Packet, error) { - next := q.next - for next != nil { - if next.priority == sqNum { - return next.val, nil - } - next = next.next - } - - return nil, ErrNotFound -} - -// Push will insert a packet in to the queue in order of sequence number -func (q *PriorityQueue) Push(val *rtp.Packet, priority uint16) { - newPq := newNode(val, priority) - if q.next == nil { - q.next = newPq - q.length++ - return - } - if priority < q.next.priority { - newPq.next = q.next - q.next.prev = newPq - q.next = newPq - q.length++ - return - } - head := q.next - prev := q.next - for head != nil { - if priority <= head.priority { - break - } - prev = head - head = head.next - } - if head == nil { - if prev != nil { - prev.next = newPq - } - newPq.prev = prev - } else { - newPq.next = head - newPq.prev = prev - if prev != nil { - prev.next = newPq - } - head.prev = newPq - } - q.length++ -} - -// Length will get the total length of the queue -func (q *PriorityQueue) Length() uint16 { - return q.length -} - -// Pop removes the first element from the queue, regardless -// sequence number -func (q *PriorityQueue) Pop() (*rtp.Packet, error) { - if q.next == nil { - return nil, ErrInvalidOperation - } - val := q.next.val - q.next.val = nil - q.length-- - q.next = q.next.next - return val, nil -} - -// PopAt removes an element at the specified sequence number (priority) -func (q *PriorityQueue) PopAt(sqNum uint16) (*rtp.Packet, error) { - if q.next == nil { - return nil, ErrInvalidOperation - } - if q.next.priority == sqNum { - val := q.next.val - q.next.val = nil - q.next = q.next.next - q.length-- - return val, nil - } - pos := q.next - prev := q.next.prev - for pos != nil { - if pos.priority == sqNum { - val := pos.val - pos.val = nil - prev.next = pos.next - if prev.next != nil { - prev.next.prev = prev - } - q.length-- - return val, nil - } - prev = pos - pos = pos.next - } - return nil, ErrNotFound -} - -// PopAtTimestamp removes and returns a packet at the given RTP Timestamp, regardless -// sequence number order -func (q *PriorityQueue) PopAtTimestamp(timestamp uint32) (*rtp.Packet, error) { - if q.next == nil { - return nil, ErrInvalidOperation - } - if q.next.val.Timestamp == timestamp { - val := q.next.val - q.next.val = nil - q.next = q.next.next - q.length-- - return val, nil - } - pos := q.next - prev := q.next.prev - for pos != nil { - if pos.val.Timestamp == timestamp { - val := pos.val - pos.val = nil - prev.next = pos.next - if prev.next != nil { - prev.next.prev = prev - } - q.length-- - return val, nil - } - prev = pos - pos = pos.next - } - return nil, ErrNotFound -} - -// Clear will empty a PriorityQueue -func (q *PriorityQueue) Clear() { - next := q.next - q.length = 0 - for next != nil { - next.prev = nil - next = next.next - } -} diff --git a/pkg/jitterbuffer/priority_queue_test.go b/pkg/jitterbuffer/priority_queue_test.go deleted file mode 100644 index 8b8d23e..0000000 --- a/pkg/jitterbuffer/priority_queue_test.go +++ /dev/null @@ -1,184 +0,0 @@ -// SPDX-FileCopyrightText: 2023 The Pion community -// SPDX-License-Identifier: MIT - -package jitterbuffer - -import ( - "runtime" - "sync/atomic" - "testing" - "time" - - "github.com/pion/rtp" - "github.com/stretchr/testify/assert" -) - -func TestPriorityQueue(t *testing.T) { - assert := assert.New(t) - - t.Run("Appends packets in order", func(*testing.T) { - pkt := &rtp.Packet{Header: rtp.Header{SequenceNumber: 5000, Timestamp: 500}, Payload: []byte{0x02}} - q := NewQueue() - q.Push(pkt, pkt.SequenceNumber) - pkt2 := &rtp.Packet{Header: rtp.Header{SequenceNumber: 5004, Timestamp: 500}, Payload: []byte{0x02}} - q.Push(pkt2, pkt2.SequenceNumber) - assert.Equal(q.next.next.val, pkt2) - assert.Equal(q.next.priority, uint16(5000)) - assert.Equal(q.next.next.priority, uint16(5004)) - }) - - t.Run("Appends many in order", func(*testing.T) { - q := NewQueue() - for i := 0; i < 100; i++ { - q.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: uint16(5012 + i), Timestamp: uint32(512 + i)}, Payload: []byte{0x02}}, uint16(5012+i)) - } - assert.Equal(uint16(100), q.Length()) - last := (*node)(nil) - cur := q.next - for cur != nil { - last = cur - cur = cur.next - if cur != nil { - assert.Equal(cur.priority, last.priority+1) - } - } - assert.Equal(q.next.priority, uint16(5012)) - assert.Equal(last.priority, uint16(5012+99)) - }) - - t.Run("Can remove an element", func(*testing.T) { - pkt := &rtp.Packet{Header: rtp.Header{SequenceNumber: 5000, Timestamp: 500}, Payload: []byte{0x02}} - q := NewQueue() - q.Push(pkt, pkt.SequenceNumber) - pkt2 := &rtp.Packet{Header: rtp.Header{SequenceNumber: 5004, Timestamp: 500}, Payload: []byte{0x02}} - q.Push(pkt2, pkt2.SequenceNumber) - for i := 0; i < 100; i++ { - q.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: uint16(5012 + i), Timestamp: uint32(512 + i)}, Payload: []byte{0x02}}, uint16(5012+i)) - } - popped, _ := q.Pop() - assert.Equal(popped.SequenceNumber, uint16(5000)) - _, _ = q.Pop() - nextPop, _ := q.Pop() - assert.Equal(nextPop.SequenceNumber, uint16(5012)) - }) - - t.Run("Appends in order", func(*testing.T) { - q := NewQueue() - for i := 0; i < 100; i++ { - q.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: uint16(5012 + i), Timestamp: uint32(512 + i)}, Payload: []byte{0x02}}, uint16(5012+i)) - } - assert.Equal(uint16(100), q.Length()) - pkt := &rtp.Packet{Header: rtp.Header{SequenceNumber: 5000, Timestamp: 500}, Payload: []byte{0x02}} - q.Push(pkt, pkt.SequenceNumber) - assert.Equal(pkt, q.next.val) - assert.Equal(uint16(101), q.Length()) - assert.Equal(q.next.priority, uint16(5000)) - }) - - t.Run("Can find", func(*testing.T) { - q := NewQueue() - for i := 0; i < 100; i++ { - q.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: uint16(5012 + i), Timestamp: uint32(512 + i)}, Payload: []byte{0x02}}, uint16(5012+i)) - } - pkt, err := q.Find(5012) - assert.Equal(pkt.SequenceNumber, uint16(5012)) - assert.Equal(err, nil) - }) - - t.Run("Updates the length when PopAt* are called", func(*testing.T) { - pkt := &rtp.Packet{Header: rtp.Header{SequenceNumber: 5000, Timestamp: 500}, Payload: []byte{0x02}} - q := NewQueue() - q.Push(pkt, pkt.SequenceNumber) - pkt2 := &rtp.Packet{Header: rtp.Header{SequenceNumber: 5004, Timestamp: 500}, Payload: []byte{0x02}} - q.Push(pkt2, pkt2.SequenceNumber) - for i := 0; i < 100; i++ { - q.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: uint16(5012 + i), Timestamp: uint32(512 + i)}, Payload: []byte{0x02}}, uint16(5012+i)) - } - assert.Equal(uint16(102), q.Length()) - popped, _ := q.PopAt(uint16(5012)) - assert.Equal(popped.SequenceNumber, uint16(5012)) - assert.Equal(uint16(101), q.Length()) - - popped, err := q.PopAtTimestamp(uint32(500)) - assert.Equal(popped.SequenceNumber, uint16(5000)) - assert.Equal(uint16(100), q.Length()) - assert.Equal(err, nil) - }) -} - -func TestPriorityQueue_Find(t *testing.T) { - packets := NewQueue() - - packets.Push(&rtp.Packet{ - Header: rtp.Header{ - SequenceNumber: 1000, - Timestamp: 5, - SSRC: 5, - }, - Payload: []uint8{0xA}, - }, 1000) - - _, err := packets.PopAt(1000) - assert.NoError(t, err) - - _, err = packets.Find(1001) - assert.Error(t, err) -} - -func TestPriorityQueue_Clean(t *testing.T) { - packets := NewQueue() - packets.Clear() - packets.Push(&rtp.Packet{ - Header: rtp.Header{ - SequenceNumber: 1000, - Timestamp: 5, - SSRC: 5, - }, - Payload: []uint8{0xA}, - }, 1000) - assert.EqualValues(t, 1, packets.Length()) - packets.Clear() -} - -func TestPriorityQueue_Unreference(t *testing.T) { - packets := NewQueue() - - var refs int64 - finalizer := func(*rtp.Packet) { - atomic.AddInt64(&refs, -1) - } - - numPkts := 100 - for i := 0; i < numPkts; i++ { - atomic.AddInt64(&refs, 1) - seq := uint16(i) - p := rtp.Packet{ - Header: rtp.Header{ - SequenceNumber: seq, - Timestamp: uint32(i + 42), - }, - Payload: []byte{byte(i)}, - } - runtime.SetFinalizer(&p, finalizer) - packets.Push(&p, seq) - } - for i := 0; i < numPkts-1; i++ { - switch i % 3 { - case 0: - packets.Pop() //nolint - case 1: - packets.PopAt(uint16(i)) //nolint - case 2: - packets.PopAtTimestamp(uint32(i + 42)) //nolint - } - } - - runtime.GC() - time.Sleep(10 * time.Millisecond) - - remainedRefs := atomic.LoadInt64(&refs) - runtime.KeepAlive(packets) - - // only the last packet should be still referenced - assert.Equal(t, int64(1), remainedRefs) -} diff --git a/pkg/jitterbuffer/receiver_interceptor.go b/pkg/jitterbuffer/receiver_interceptor.go index bd7d544..3d639a2 100644 --- a/pkg/jitterbuffer/receiver_interceptor.go +++ b/pkg/jitterbuffer/receiver_interceptor.go @@ -52,10 +52,11 @@ func (g *InterceptorFactory) NewInterceptor(_ string) (interceptor.Interceptor, // arriving) quickly enough. type ReceiverInterceptor struct { interceptor.NoOp - buffer *JitterBuffer - wg sync.WaitGroup - close chan struct{} - log logging.LeveledLogger + buffer *JitterBuffer + bufferMu sync.Mutex + wg sync.WaitGroup + close chan struct{} + log logging.LeveledLogger } // NewInterceptor returns a new InterceptorFactory @@ -77,6 +78,9 @@ func (i *ReceiverInterceptor) BindRemoteStream(_ *interceptor.StreamInfo, reader return 0, nil, err } + i.bufferMu.Lock() + defer i.bufferMu.Unlock() + i.buffer.Push(packet) if i.buffer.state == Emitting { newPkt, err := i.buffer.Pop() @@ -92,7 +96,10 @@ func (i *ReceiverInterceptor) BindRemoteStream(_ *interceptor.StreamInfo, reader // UnbindRemoteStream is called when the Stream is removed. It can be used to clean up any data related to that track. func (i *ReceiverInterceptor) UnbindRemoteStream(_ *interceptor.StreamInfo) { + i.bufferMu.Lock() defer i.wg.Wait() + defer i.bufferMu.Unlock() + i.buffer.Clear(true) } diff --git a/pkg/nack/responder_interceptor.go b/pkg/nack/responder_interceptor.go index 45a3777..8dc9e30 100644 --- a/pkg/nack/responder_interceptor.go +++ b/pkg/nack/responder_interceptor.go @@ -142,7 +142,7 @@ func (n *ResponderInterceptor) resendPackets(nack *rtcp.TransportLayerNack) { if _, err := stream.rtpWriter.Write(p.Header(), p.Payload(), interceptor.Attributes{}); err != nil { n.log.Warnf("failed resending nacked packet: %+v", err) } - p.Release() + p.Release(false) } return true