From d304eddee940a89e7c894aad529cfdc392860b29 Mon Sep 17 00:00:00 2001 From: Rob Elsner Date: Sun, 21 Apr 2024 08:55:05 -0400 Subject: [PATCH] JitterBuffer: Fix queue not properly decrementing packet count in some instances --- pkg/jitterbuffer/jitter_buffer_test.go | 6 ++++++ pkg/jitterbuffer/priority_queue.go | 4 ++++ pkg/jitterbuffer/priority_queue_test.go | 20 ++++++++++++++++++++ 3 files changed, 30 insertions(+) diff --git a/pkg/jitterbuffer/jitter_buffer_test.go b/pkg/jitterbuffer/jitter_buffer_test.go index 0ed73023..cb45a40d 100644 --- a/pkg/jitterbuffer/jitter_buffer_test.go +++ b/pkg/jitterbuffer/jitter_buffer_test.go @@ -45,6 +45,10 @@ func TestJitterBuffer(t *testing.T) { }) t.Run("Appends packets and begins playout", func(*testing.T) { jb := New(WithMinimumPacketCount(1)) + events := make([]Event, 0) + jb.Listen(BeginPlayback, func(event Event, jb *JitterBuffer) { + events = append(events, BeginPlayback) + }) for i := 0; i < 2; i++ { jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: uint16(5012 + i), Timestamp: uint32(512 + i)}, Payload: []byte{0x02}}) } @@ -54,6 +58,8 @@ func TestJitterBuffer(t *testing.T) { head, err := jb.Pop() assert.Equal(head.SequenceNumber, uint16(5012)) assert.Equal(err, nil) + assert.Equal(1, len(events)) + assert.Equal(Event(BeginPlayback), events[0]) }) t.Run("Wraps playout correctly", func(*testing.T) { diff --git a/pkg/jitterbuffer/priority_queue.go b/pkg/jitterbuffer/priority_queue.go index 366ff10a..50e70c81 100644 --- a/pkg/jitterbuffer/priority_queue.go +++ b/pkg/jitterbuffer/priority_queue.go @@ -127,6 +127,7 @@ func (q *PriorityQueue) PopAt(sqNum uint16) (*rtp.Packet, error) { if q.next.priority == sqNum { val := q.next.val q.next = q.next.next + q.length-- return val, nil } pos := q.next @@ -138,6 +139,7 @@ func (q *PriorityQueue) PopAt(sqNum uint16) (*rtp.Packet, error) { if prev.next != nil { prev.next.prev = prev } + q.length-- return val, nil } prev = pos @@ -155,6 +157,7 @@ func (q *PriorityQueue) PopAtTimestamp(timestamp uint32) (*rtp.Packet, error) { if q.next.val.Timestamp == timestamp { val := q.next.val q.next = q.next.next + q.length-- return val, nil } pos := q.next @@ -166,6 +169,7 @@ func (q *PriorityQueue) PopAtTimestamp(timestamp uint32) (*rtp.Packet, error) { if prev.next != nil { prev.next.prev = prev } + q.length-- return val, nil } prev = pos diff --git a/pkg/jitterbuffer/priority_queue_test.go b/pkg/jitterbuffer/priority_queue_test.go index 44d0e010..d28cc720 100644 --- a/pkg/jitterbuffer/priority_queue_test.go +++ b/pkg/jitterbuffer/priority_queue_test.go @@ -81,6 +81,26 @@ func TestPriorityQueue(t *testing.T) { assert.Equal(pkt.SequenceNumber, uint16(5012)) assert.Equal(err, nil) }) + + t.Run("Updates the length when PopAt* are called", func(*testing.T) { + pkt := &rtp.Packet{Header: rtp.Header{SequenceNumber: 5000, Timestamp: 500}, Payload: []byte{0x02}} + q := NewQueue() + q.Push(pkt, pkt.SequenceNumber) + pkt2 := &rtp.Packet{Header: rtp.Header{SequenceNumber: 5004, Timestamp: 500}, Payload: []byte{0x02}} + q.Push(pkt2, pkt2.SequenceNumber) + for i := 0; i < 100; i++ { + q.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: uint16(5012 + i), Timestamp: uint32(512 + i)}, Payload: []byte{0x02}}, uint16(5012+i)) + } + assert.Equal(uint16(102), q.Length()) + popped, _ := q.PopAt(uint16(5012)) + assert.Equal(popped.SequenceNumber, uint16(5012)) + assert.Equal(uint16(101), q.Length()) + + popped, err := q.PopAtTimestamp(uint32(500)) + assert.Equal(popped.SequenceNumber, uint16(5000)) + assert.Equal(uint16(100), q.Length()) + assert.Equal(err, nil) + }) } func TestPriorityQueue_Find(t *testing.T) {