diff --git a/pkg/jitterbuffer/jitter_buffer.go b/pkg/jitterbuffer/jitter_buffer.go index b66d5014..56cb6c9e 100644 --- a/pkg/jitterbuffer/jitter_buffer.go +++ b/pkg/jitterbuffer/jitter_buffer.go @@ -65,14 +65,15 @@ type ( // order, and allows removing in either sequence number order or via a // provided timestamp type JitterBuffer struct { - packets *PriorityQueue - lastSequence uint16 - playoutHead uint16 - playoutReady bool - state State - stats Stats - listeners map[Event][]EventListener - mutex sync.Mutex + packets *PriorityQueue + minStartCount uint16 + lastSequence uint16 + playoutHead uint16 + playoutReady bool + state State + stats Stats + listeners map[Event][]EventListener + mutex sync.Mutex } // Stats Track interesting statistics for the life of this JitterBuffer @@ -90,13 +91,21 @@ type Stats struct { // New will initialize a jitter buffer and its associated statistics func New(opts ...Option) *JitterBuffer { - jb := &JitterBuffer{state: Buffering, stats: Stats{0, 0, 0}, packets: NewQueue(), listeners: make(map[Event][]EventListener)} + jb := &JitterBuffer{state: Buffering, stats: Stats{0, 0, 0}, minStartCount: 50, packets: NewQueue(), listeners: make(map[Event][]EventListener)} for _, o := range opts { o(jb) } return jb } +// WithMinimumPacketCount will set the required number of packets to be received before +// any attempt to pop a packet can succeed +func WithMinimumPacketCount(count uint16) Option { + return func(jb *JitterBuffer) { + jb.minStartCount = count + } +} + // Listen will register an event listener // The jitter buffer may emit events correspnding, interested listerns should // look at Event for available events @@ -142,7 +151,7 @@ func (jb *JitterBuffer) emit(event Event) { func (jb *JitterBuffer) updateState() { // For now, we only look at the number of packets captured in the play buffer - if jb.packets.Length() >= 50 && jb.state == Buffering { + if jb.packets.Length() >= jb.minStartCount && jb.state == Buffering { jb.state = Emitting jb.playoutReady = true jb.emit(BeginPlayback) @@ -186,6 +195,36 @@ func (jb *JitterBuffer) Pop() (*rtp.Packet, error) { return packet, nil } +// PopAtSequence will pop an RTP packet from the jitter buffer at the specified Sequence +func (jb *JitterBuffer) PopAtSequence(sq uint16) (*rtp.Packet, error) { + jb.mutex.Lock() + defer jb.mutex.Unlock() + if jb.state != Emitting { + return nil, ErrPopWhileBuffering + } + packet, err := jb.packets.PopAt(sq) + if err != nil { + jb.stats.underflowCount++ + jb.emit(BufferUnderflow) + return (*rtp.Packet)(nil), err + } + jb.playoutHead = (jb.playoutHead + 1) % math.MaxUint16 + jb.updateState() + return packet, nil +} + +// PeekAtSequence will return an RTP packet from the jitter buffer at the specified Sequence +// without removing it from the buffer +func (jb *JitterBuffer) PeekAtSequence(sq uint16) (*rtp.Packet, error) { + jb.mutex.Lock() + defer jb.mutex.Unlock() + packet, err := jb.packets.Find(sq) + if err != nil { + return (*rtp.Packet)(nil), err + } + return packet, nil +} + // PopAtTimestamp pops an RTP packet from the jitter buffer with the provided timestamp // Call this method repeatedly to drain the buffer at the timestamp func (jb *JitterBuffer) PopAtTimestamp(ts uint32) (*rtp.Packet, error) { diff --git a/pkg/jitterbuffer/jitter_buffer_test.go b/pkg/jitterbuffer/jitter_buffer_test.go index ff97d2ca..b803a98f 100644 --- a/pkg/jitterbuffer/jitter_buffer_test.go +++ b/pkg/jitterbuffer/jitter_buffer_test.go @@ -29,7 +29,6 @@ func TestJitterBuffer(t *testing.T) { assert.Equal(jb.packets.Length(), uint16(4)) assert.Equal(jb.lastSequence, uint16(5012)) }) - t.Run("Appends packets and begins playout", func(t *testing.T) { jb := New() for i := 0; i < 100; i++ { @@ -42,6 +41,18 @@ func TestJitterBuffer(t *testing.T) { assert.Equal(head.SequenceNumber, uint16(5012)) assert.Equal(err, nil) }) + t.Run("Appends packets and begins playout", func(t *testing.T) { + jb := New(WithMinimumPacketCount(1)) + for i := 0; i < 2; i++ { + jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: uint16(5012 + i), Timestamp: uint32(512 + i)}, Payload: []byte{0x02}}) + } + assert.Equal(jb.packets.Length(), uint16(2)) + assert.Equal(jb.state, Emitting) + assert.Equal(jb.playoutHead, uint16(5012)) + head, err := jb.Pop() + assert.Equal(head.SequenceNumber, uint16(5012)) + assert.Equal(err, nil) + }) t.Run("Wraps playout correctly", func(t *testing.T) { jb := New() for i := 0; i < 100; i++ { @@ -99,6 +110,20 @@ func TestJitterBuffer(t *testing.T) { assert.Equal(pkt.SequenceNumber, uint16(5000)) assert.Equal(err, nil) }) +t.Run("Pops at sequence with an invalid sequence number", func(t *testing.T) { + jb := New() + for i := 0; i < 50; i++ { + sqnum := uint16((math.MaxUint16 - 32 + i) % math.MaxUint16) + jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: sqnum, Timestamp: uint32(512 + i)}, Payload: []byte{0x02}}) + } + jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 1019, Timestamp: uint32(9000)}, Payload: []byte{0x02}}) + jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 1020, Timestamp: uint32(9000)}, Payload: []byte{0x02}}) + assert.Equal(jb.packets.Length(), uint16(52)) + assert.Equal(jb.state, Emitting) + head, err := jb.PopAtSequence(uint16(9000)) + assert.Equal(head, (*rtp.Packet)(nil)) + assert.NotEqual(err, nil) + }) t.Run("Pops at timestamp with multiple packets", func(t *testing.T) { jb := New() for i := 0; i < 50; i++ { @@ -120,4 +145,25 @@ func TestJitterBuffer(t *testing.T) { assert.Equal(head.SequenceNumber, uint16(math.MaxUint16-32)) assert.Equal(err, nil) }) + t.Run("Peeks at timestamp with multiple packets", func(t *testing.T) { + jb := New() + for i := 0; i < 50; i++ { + sqnum := uint16((math.MaxUint16 - 32 + i) % math.MaxUint16) + jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: sqnum, Timestamp: uint32(512 + i)}, Payload: []byte{0x02}}) + } + jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 1019, Timestamp: uint32(9000)}, Payload: []byte{0x02}}) + jb.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: 1020, Timestamp: uint32(9000)}, Payload: []byte{0x02}}) + assert.Equal(jb.packets.Length(), uint16(52)) + assert.Equal(jb.state, Emitting) + head, err := jb.PeekAtSequence(uint16(1019)) + assert.Equal(head.SequenceNumber, uint16(1019)) + assert.Equal(err, nil) + head, err = jb.PeekAtSequence(uint16(1020)) + assert.Equal(head.SequenceNumber, uint16(1020)) + assert.Equal(err, nil) + + head, err = jb.PopAtSequence(uint16(math.MaxUint16 - 32)) + assert.Equal(head.SequenceNumber, uint16(math.MaxUint16-32)) + assert.Equal(err, nil) + }) }