diff --git a/internal/test/mock_stream.go b/internal/test/mock_stream.go index bf96e31b..dd545498 100644 --- a/internal/test/mock_stream.go +++ b/internal/test/mock_stream.go @@ -92,7 +92,6 @@ func NewMockStream(info *interceptor.StreamInfo, i interceptor.Interceptor) *Moc if !ok { return 0, nil, io.EOF } - marshaled, err := p.Marshal() if err != nil { return 0, nil, io.EOF diff --git a/pkg/jitterbuffer/jitter_buffer.go b/pkg/jitterbuffer/jitter_buffer.go index 863a0a80..45fd278d 100644 --- a/pkg/jitterbuffer/jitter_buffer.go +++ b/pkg/jitterbuffer/jitter_buffer.go @@ -64,7 +64,7 @@ type ( // order, and allows removing in either sequence number order or via a // provided timestamp type JitterBuffer struct { - packets *PriorityQueue + packets *RBTree minStartCount uint16 lastSequence uint16 playoutHead uint16 @@ -128,6 +128,12 @@ func (jb *JitterBuffer) PlayoutHead() uint16 { return jb.playoutHead } +func (jb *JitterBuffer) Length() uint16 { + jb.mutex.Lock() + defer jb.mutex.Unlock() + return jb.packets.Length() +} + // SetPlayoutHead allows you to manually specify the packet you wish to pop next // If you have encountered a packet that hasn't resolved you can skip it func (jb *JitterBuffer) SetPlayoutHead(playoutHead uint16) { @@ -155,7 +161,7 @@ func (jb *JitterBuffer) Push(packet *rtp.Packet) { if jb.packets.Length() == 0 { jb.emit(StartBuffering) } - if jb.packets.Length() > 100 { + if jb.packets.Length() > 2*jb.minStartCount { jb.stats.overflowCount++ jb.emit(BufferOverflow) } @@ -240,8 +246,6 @@ func (jb *JitterBuffer) PopAtSequence(sq uint16) (*rtp.Packet, error) { // PeekAtSequence will return an RTP packet from the jitter buffer at the specified Sequence // without removing it from the buffer func (jb *JitterBuffer) PeekAtSequence(sq uint16) (*rtp.Packet, error) { - jb.mutex.Lock() - defer jb.mutex.Unlock() packet, err := jb.packets.Find(sq) if err != nil { return nil, err diff --git a/pkg/jitterbuffer/jitter_buffer_test.go b/pkg/jitterbuffer/jitter_buffer_test.go index f163a8f3..74b99ad9 100644 --- a/pkg/jitterbuffer/jitter_buffer_test.go +++ b/pkg/jitterbuffer/jitter_buffer_test.go @@ -112,6 +112,7 @@ func TestJitterBuffer(t *testing.T) { assert.Equal(jb.packets.Length(), uint16(100)) assert.Equal(jb.state, Emitting) head, err := jb.PopAtTimestamp(uint32(513)) + assert.NotNil(head) assert.Equal(head.SequenceNumber, uint16(math.MaxUint16-32+1)) assert.Equal(err, nil) head, err = jb.PopAtTimestamp(uint32(513)) diff --git a/pkg/jitterbuffer/option.go b/pkg/jitterbuffer/option.go index 9a33c22e..b22a3f41 100644 --- a/pkg/jitterbuffer/option.go +++ b/pkg/jitterbuffer/option.go @@ -17,3 +17,10 @@ func Log(log logging.LeveledLogger) ReceiverInterceptorOption { return nil } } + +func WithSkipMissingPackets() ReceiverInterceptorOption { + return func(d *ReceiverInterceptor) error { + d.skipMissingPackets = true + return nil + } +} diff --git a/pkg/jitterbuffer/priority_queue.go b/pkg/jitterbuffer/priority_queue.go index 11a8679c..34b380cd 100644 --- a/pkg/jitterbuffer/priority_queue.go +++ b/pkg/jitterbuffer/priority_queue.go @@ -1,194 +1,468 @@ -// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-FileCopyblackText: 2023 The Pion community // SPDX-License-Identifier: MIT package jitterbuffer import ( "errors" + "fmt" "github.com/pion/rtp" ) -// PriorityQueue provides a linked list sorting of RTP packets by SequenceNumber -type PriorityQueue struct { - next *node +type treeColor bool +type direction int8 + +const ( + red, black treeColor = false, true +) +const ( + left, right direction = 0, 1 +) + +type RBNode struct { + parent, left, right *RBNode + priority uint16 + val *rtp.Packet + color treeColor +} + +type RBTree struct { + root *RBNode length uint16 } -type node struct { - val *rtp.Packet - next *node - prev *node - priority uint16 +func NewTree() *RBTree { + return &RBTree{} } -var ( - // ErrInvalidOperation may be returned if a Pop or Find operation is performed on an empty queue - ErrInvalidOperation = errors.New("attempt to find or pop on an empty list") - // ErrNotFound will be returned if the packet cannot be found in the queue - ErrNotFound = errors.New("priority not found") -) +func (tree *RBTree) RotateRight(x *RBNode) error { + if x == nil || x.left == nil { + return errors.New("RotateRight: node or left child is nil") + } -// NewQueue will create a new PriorityQueue whose order relies on monotonically -// increasing Sequence Number, wrapping at MaxUint16, so -// a packet with sequence number MaxUint16 - 1 will be after 0 -func NewQueue() *PriorityQueue { - return &PriorityQueue{ - next: nil, - length: 0, + y := x.left + x.left = y.right + if y.right != nil { + y.right.parent = x } + + y.parent = x.parent + if x.parent == nil { + tree.root = y + } else if x == x.parent.right { + x.parent.right = y + } else { + x.parent.left = y + } + + y.right = x + x.parent = y + + return nil } -func newNode(val *rtp.Packet, priority uint16) *node { - return &node{ - val: val, - prev: nil, - next: nil, - priority: priority, +func (tree *RBTree) RotateLeft(x *RBNode) error { + if x == nil || x.right == nil { + return errors.New("RotateLeft: node or right child is nil") + } + + y := x.right + x.right = y.left + if y.left != nil { + y.left.parent = x } + + y.parent = x.parent + if x.parent == nil { + tree.root = y + } else if x == x.parent.left { + x.parent.left = y + } else { + x.parent.right = y + } + + y.left = x + x.parent = y + + return nil } -// Find a packet in the queue with the provided sequence number, -// regardless of position (the packet is retained in the queue) -func (q *PriorityQueue) Find(sqNum uint16) (*rtp.Packet, error) { - next := q.next - for next != nil { - if next.priority == sqNum { - return next.val, nil +func (t *RBTree) Insert(pkt *rtp.Packet) { + n := &RBNode{ + val: pkt, + priority: pkt.SequenceNumber, + color: red, + } + t.length++ + if t.root == nil { + t.root = n + t.root.color = black + return + } + + current := t.root + var parent *RBNode + for current != nil { + parent = current + if n.priority < current.priority { + current = current.left + } else { + current = current.right } - next = next.next } - return nil, ErrNotFound + n.parent = parent + if n.priority < parent.priority { + parent.left = n + } else { + parent.right = n + } + + t.fixInsert(n) } -// Push will insert a packet in to the queue in order of sequence number -func (q *PriorityQueue) Push(val *rtp.Packet, priority uint16) { - newPq := newNode(val, priority) - if q.next == nil { - q.next = newPq - q.length++ - return +func (t *RBTree) fixInsert(n *RBNode) { + n.color = red + + for n != t.root && n.parent != nil && n.parent.color == red { + if n.parent.parent == nil { + break + } + + grandparent := n.parent.parent + isParentLeft := n.parent == grandparent.left + uncle := grandparent.right + if !isParentLeft { + uncle = grandparent.left + } + + if uncle != nil && uncle.color == red { + uncle.color = black + n.parent.color = black + if grandparent != t.root { + grandparent.color = red + } + n = grandparent + continue + } + + if isParentLeft { + if n == n.parent.right { + t.RotateLeft(n.parent) + n = n.left + } + n.parent.color = black + grandparent.color = red + t.RotateRight(grandparent) + } else { + if n == n.parent.left { + t.RotateRight(n.parent) + n = n.right + } + n.parent.color = black + grandparent.color = red + t.RotateLeft(grandparent) + } + } + + t.root.color = black +} + +func (t *RBTree) Find(priority uint16) (*rtp.Packet, error) { + current := t.root + for current != nil { + if priority == current.priority { + return current.val, nil + } + if priority < current.priority { + current = current.left + } else { + current = current.right + } } - if priority < q.next.priority { - newPq.next = q.next - q.next.prev = newPq - q.next = newPq - q.length++ + return nil, ErrNotFound +} + +func (t *RBTree) PrettyPrint() { + if t.root == nil { + fmt.Println("Empty tree") return } - head := q.next - prev := q.next - for head != nil { - if priority <= head.priority { - break + + // Helper function to get node details + nodeInfo := func(n *RBNode) string { + if n == nil { + return "NIL(B)" } - prev = head - head = head.next + color := "R" + if n.color == black { + color = "B" + } + return fmt.Sprintf("%d(%s)", n.priority, color) } - if head == nil { - if prev != nil { - prev.next = newPq + + var printNode func(node *RBNode, prefix string, isLeft bool) + printNode = func(node *RBNode, prefix string, isLeft bool) { + if node == nil { + return } - newPq.prev = prev - } else { - newPq.next = head - newPq.prev = prev - if prev != nil { - prev.next = newPq + + nodePrefix := "└──" + childPrefix := " " + if isLeft { + nodePrefix = "├──" + childPrefix = "│ " } - head.prev = newPq + + fmt.Printf("%s%s%s\n", prefix, nodePrefix, nodeInfo(node)) + printNode(node.left, prefix+childPrefix, true) + printNode(node.right, prefix+childPrefix, false) } - q.length++ + + fmt.Printf("%s\n", nodeInfo(t.root)) + printNode(t.root.left, "", true) + printNode(t.root.right, "", false) +} + +var ( + // ErrInvalidOperation may be returned if a Pop or Find operation is performed on an empty queue + ErrInvalidOperation = errors.New("attempt to find or pop on an empty list") + // ErrNotFound will be returned if the packet cannot be found in the queue + ErrNotFound = errors.New("priority not found") +) + +// NewQueue will create a new PriorityQueue whose order relies on monotonically +// increasing Sequence Number, wrapping at MaxUint16, so +// a packet with sequence number MaxUint16 - 1 will be after 0 +func NewQueue() *RBTree { + return &RBTree{} +} + +// Push will insert a packet in to the queue in order of sequence number +func (q *RBTree) Push(val *rtp.Packet, priority uint16) { + q.Insert(val) } // Length will get the total length of the queue -func (q *PriorityQueue) Length() uint16 { +func (q *RBTree) Length() uint16 { return q.length } // Pop removes the first element from the queue, regardless // sequence number -func (q *PriorityQueue) Pop() (*rtp.Packet, error) { - if q.next == nil { - return nil, ErrInvalidOperation +func (q *RBTree) Pop() (*rtp.Packet, error) { + node := q.root + err := q.Delete(node.priority) + if err != nil { + return nil, err } - val := q.next.val - q.next.val = nil - q.length-- - q.next = q.next.next - return val, nil + return node.val, nil } // PopAt removes an element at the specified sequence number (priority) -func (q *PriorityQueue) PopAt(sqNum uint16) (*rtp.Packet, error) { - if q.next == nil { - return nil, ErrInvalidOperation - } - if q.next.priority == sqNum { - val := q.next.val - q.next.val = nil - q.next = q.next.next - q.length-- - return val, nil - } - pos := q.next - prev := q.next.prev - for pos != nil { - if pos.priority == sqNum { - val := pos.val - pos.val = nil - prev.next = pos.next - if prev.next != nil { - prev.next.prev = prev - } - q.length-- - return val, nil - } - prev = pos - pos = pos.next +func (q *RBTree) PopAt(sqNum uint16) (*rtp.Packet, error) { + pkt, err := q.Find(sqNum) + if err != nil { + return nil, err } - return nil, ErrNotFound + + err = q.Delete(sqNum) + if err != nil { + return nil, err + } + return pkt, nil } // PopAtTimestamp removes and returns a packet at the given RTP Timestamp, regardless // sequence number order -func (q *PriorityQueue) PopAtTimestamp(timestamp uint32) (*rtp.Packet, error) { - if q.next == nil { - return nil, ErrInvalidOperation - } - if q.next.val.Timestamp == timestamp { - val := q.next.val - q.next.val = nil - q.next = q.next.next - q.length-- - return val, nil - } - pos := q.next - prev := q.next.prev - for pos != nil { - if pos.val.Timestamp == timestamp { - val := pos.val - pos.val = nil - prev.next = pos.next - if prev.next != nil { - prev.next.prev = prev +func (q *RBTree) PopAtTimestamp(timestamp uint32) (*rtp.Packet, error) { + if q.root == nil { + return nil, ErrNotFound + } + + queue := []*RBNode{q.root} + + for len(queue) > 0 { + node := queue[0] + queue = queue[1:] + + if node.val.Timestamp == timestamp { + pkt := node.val + err := q.Delete(node.priority) + if err != nil { + return nil, err } - q.length-- - return val, nil + return pkt, nil + } + + if node.left != nil { + queue = append(queue, node.left) + } + if node.right != nil { + queue = append(queue, node.right) } - prev = pos - pos = pos.next } + return nil, ErrNotFound } // Clear will empty a PriorityQueue -func (q *PriorityQueue) Clear() { - next := q.next +func (q *RBTree) Clear() { + q.clear(q.root) + q.root = nil q.length = 0 - for next != nil { - next.prev = nil - next = next.next +} +func (q *RBTree) clear(n *RBNode) { + if n == nil { + return + } + q.clear(n.left) + q.clear(n.right) + n = nil +} + +// Find a node by priority +func (t *RBTree) Peek(priority uint16) (*rtp.Packet, error) { + return t.Find(priority) +} + +// Delete removes a node with the given priority from the tree +func (t *RBTree) Delete(priority uint16) error { + node := t.root + for node != nil && node.priority != priority { + if priority < node.priority { + node = node.left + } else { + node = node.right + } + } + if node == nil { + return ErrNotFound + } + t.length-- + // y is the node to be removed from the tree + // If node has less than 2 children, y = node + // If node has 2 children, y = successor + var y *RBNode + var x *RBNode // x is y's only child (or nil) + + if node.left == nil || node.right == nil { + y = node + } else { + // Find successor (smallest value in right subtree) + y = node.right + for y.left != nil { + y = y.left + } + } + + if y.left != nil { + x = y.left + } else { + x = y.right + } + + if x != nil { + x.parent = y.parent + } + if y.parent == nil { + t.root = x + } else if y == y.parent.left { + y.parent.left = x + } else { + y.parent.right = x + } + + // If we removed the successor, copy its data to the original node + if y != node { + node.priority = y.priority + node.val = y.val + } + + // If we removed a black node, we need to fix the tree + if y.color == black { + t.fixDelete(x, y.parent) + } + + return nil +} + +func (t *RBTree) fixDelete(x *RBNode, parent *RBNode) { + for x != t.root && (x == nil || x.color == black) { + if x == nil && parent == nil { + break + } + + var w *RBNode + isLeft := x == parent.left + if isLeft { + w = parent.right + } else { + w = parent.left + } + + if w == nil { + x = parent + parent = parent.parent + continue + } + + if w.color == red { + w.color = black + parent.color = red + if isLeft { + t.RotateLeft(parent) + w = parent.right + } else { + t.RotateRight(parent) + w = parent.left + } + } + + wLeftBlack := w.left == nil || w.left.color == black + wRightBlack := w.right == nil || w.right.color == black + + if wLeftBlack && wRightBlack { + w.color = red + x = parent + parent = parent.parent + } else { + if isLeft { + if wRightBlack { + if w.left != nil { + w.left.color = black + } + w.color = red + t.RotateRight(w) + w = parent.right + } + w.color = parent.color + parent.color = black + if w.right != nil { + w.right.color = black + } + t.RotateLeft(parent) + } else { + if wLeftBlack { + if w.right != nil { + w.right.color = black + } + w.color = red + t.RotateLeft(w) + w = parent.left + } + w.color = parent.color + parent.color = black + if w.left != nil { + w.left.color = black + } + t.RotateRight(parent) + } + x = t.root + } + } + if x != nil { + x.color = black } } diff --git a/pkg/jitterbuffer/priority_queue_test.go b/pkg/jitterbuffer/priority_queue_test.go index 8b8d23e1..ccb9e1e6 100644 --- a/pkg/jitterbuffer/priority_queue_test.go +++ b/pkg/jitterbuffer/priority_queue_test.go @@ -4,6 +4,7 @@ package jitterbuffer import ( + "fmt" "runtime" "sync/atomic" "testing" @@ -13,172 +14,516 @@ import ( "github.com/stretchr/testify/assert" ) +func TestRotations(t *testing.T) { + t.Run("RotateLeft", func(t *testing.T) { + tree := NewTree() + + // Create a simple tree: + // 5 + // \ + // 7 + // \ + // 9 + root := &RBNode{priority: 5, color: black} + right := &RBNode{priority: 7, color: red} + rightRight := &RBNode{priority: 9, color: red} + + tree.root = root + root.right = right + right.parent = root + right.right = rightRight + rightRight.parent = right + + // After rotating left around 5: + // 7 + // / \ + // 5 9 + err := tree.RotateLeft(root) + assert.NoError(t, err) + + // Verify structure + assert.Equal(t, uint16(7), tree.root.priority) + assert.Equal(t, uint16(5), tree.root.left.priority) + assert.Equal(t, uint16(9), tree.root.right.priority) + + // Verify parent pointers + assert.Nil(t, tree.root.parent) + assert.Equal(t, tree.root, tree.root.left.parent) + assert.Equal(t, tree.root, tree.root.right.parent) + }) + + t.Run("RotateRight", func(t *testing.T) { + tree := NewTree() + + // Create a simple tree: + // 7 + // / + // 5 + // / + // 3 + root := &RBNode{priority: 7, color: black} + left := &RBNode{priority: 5, color: red} + leftLeft := &RBNode{priority: 3, color: red} + + tree.root = root + root.left = left + left.parent = root + left.left = leftLeft + leftLeft.parent = left + + // After rotating right around 7: + // 5 + // / \ + // 3 7 + err := tree.RotateRight(root) + assert.NoError(t, err) + + // Verify structure + assert.Equal(t, uint16(5), tree.root.priority) + assert.Equal(t, uint16(3), tree.root.left.priority) + assert.Equal(t, uint16(7), tree.root.right.priority) + + // Verify parent pointers + assert.Nil(t, tree.root.parent) + assert.Equal(t, tree.root, tree.root.left.parent) + assert.Equal(t, tree.root, tree.root.right.parent) + }) + + t.Run("RotateLeft error cases", func(t *testing.T) { + tree := NewTree() + root := &RBNode{priority: 5, color: black} + tree.root = root + + // Should error when trying to rotate with no right child + err := tree.RotateLeft(root) + assert.Error(t, err) + }) + + t.Run("RotateRight error cases", func(t *testing.T) { + tree := NewTree() + root := &RBNode{priority: 5, color: black} + tree.root = root + + // Should error when trying to rotate with no left child + err := tree.RotateRight(root) + assert.Error(t, err) + }) +} + func TestPriorityQueue(t *testing.T) { assert := assert.New(t) + tree := NewTree() + + t.Run("RotateRight", func(t *testing.T) { + tree.Insert(&rtp.Packet{Header: rtp.Header{SequenceNumber: 5004, Timestamp: 500}, Payload: []byte{0x02}}) + tree.Insert(&rtp.Packet{Header: rtp.Header{SequenceNumber: 5000, Timestamp: 500}, Payload: []byte{0x02}}) + tree.Insert(&rtp.Packet{Header: rtp.Header{SequenceNumber: 5002, Timestamp: 500}, Payload: []byte{0x02}}) + tree.Insert(&rtp.Packet{Header: rtp.Header{SequenceNumber: 5001, Timestamp: 500}, Payload: []byte{0x02}}) + tree.Insert(&rtp.Packet{Header: rtp.Header{SequenceNumber: 5003, Timestamp: 500}, Payload: []byte{0x02}}) + tree.Insert(&rtp.Packet{Header: rtp.Header{SequenceNumber: 5005, Timestamp: 500}, Payload: []byte{0x02}}) + tree.Insert(&rtp.Packet{Header: rtp.Header{SequenceNumber: 5006, Timestamp: 500}, Payload: []byte{0x02}}) + tree.Insert(&rtp.Packet{Header: rtp.Header{SequenceNumber: 5007, Timestamp: 500}, Payload: []byte{0x02}}) + tree.Insert(&rtp.Packet{Header: rtp.Header{SequenceNumber: 5008, Timestamp: 500}, Payload: []byte{0x02}}) + tree.Insert(&rtp.Packet{Header: rtp.Header{SequenceNumber: 5009, Timestamp: 500}, Payload: []byte{0x02}}) + tree.Insert(&rtp.Packet{Header: rtp.Header{SequenceNumber: 5010, Timestamp: 500}, Payload: []byte{0x02}}) - t.Run("Appends packets in order", func(*testing.T) { - pkt := &rtp.Packet{Header: rtp.Header{SequenceNumber: 5000, Timestamp: 500}, Payload: []byte{0x02}} - q := NewQueue() - q.Push(pkt, pkt.SequenceNumber) - pkt2 := &rtp.Packet{Header: rtp.Header{SequenceNumber: 5004, Timestamp: 500}, Payload: []byte{0x02}} - q.Push(pkt2, pkt2.SequenceNumber) - assert.Equal(q.next.next.val, pkt2) - assert.Equal(q.next.priority, uint16(5000)) - assert.Equal(q.next.next.priority, uint16(5004)) + assert.Equal(tree.root.priority, uint16(5004)) + // Verify tree maintains Red-Black properties + assert.NoError(validateRBProperties(t, tree)) + // Verify all the elements inserted are in the tree + for _, v := range []uint16{5004, 5000, 5002, 5001, 5003, 5005, 5006, 5007, 5008, 5009, 5010} { + packet, err := tree.Peek(v) + assert.NoError(err) + assert.Equal(v, packet.SequenceNumber) + } }) +} - t.Run("Appends many in order", func(*testing.T) { - q := NewQueue() - for i := 0; i < 100; i++ { - q.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: uint16(5012 + i), Timestamp: uint32(512 + i)}, Payload: []byte{0x02}}, uint16(5012+i)) +func TestRBTreeProperties(t *testing.T) { + t.Run("Tree Properties", func(t *testing.T) { + tree := NewTree() + values := []uint16{5004, 5000, 5002, 5001, 5003, 5005, 5006, 5007, 5008, 5009, 5010} + + for _, v := range values { + tree.Insert(&rtp.Packet{Header: rtp.Header{SequenceNumber: v}}) } - assert.Equal(uint16(100), q.Length()) - last := (*node)(nil) - cur := q.next - for cur != nil { - last = cur - cur = cur.next - if cur != nil { - assert.Equal(cur.priority, last.priority+1) + + // Property 1: Root must be black + assert.Equal(t, black, tree.root.color, "Root must be black") + + // Property 2: Red nodes must have black children + var checkRedNodes func(*RBNode) bool + checkRedNodes = func(n *RBNode) bool { + if n == nil { + return true + } + if n.color == red { + if n.left != nil && n.left.color == red { + t.Errorf("Red node %d has red left child %d", n.priority, n.left.priority) + return false + } + if n.right != nil && n.right.color == red { + t.Errorf("Red node %d has red right child %d", n.priority, n.right.priority) + return false + } } + return checkRedNodes(n.left) && checkRedNodes(n.right) } - assert.Equal(q.next.priority, uint16(5012)) - assert.Equal(last.priority, uint16(5012+99)) + + checkRedNodes(tree.root) }) - t.Run("Can remove an element", func(*testing.T) { - pkt := &rtp.Packet{Header: rtp.Header{SequenceNumber: 5000, Timestamp: 500}, Payload: []byte{0x02}} - q := NewQueue() - q.Push(pkt, pkt.SequenceNumber) - pkt2 := &rtp.Packet{Header: rtp.Header{SequenceNumber: 5004, Timestamp: 500}, Payload: []byte{0x02}} - q.Push(pkt2, pkt2.SequenceNumber) - for i := 0; i < 100; i++ { - q.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: uint16(5012 + i), Timestamp: uint32(512 + i)}, Payload: []byte{0x02}}, uint16(5012+i)) + t.Run("Black Path Property", func(t *testing.T) { + tree := NewTree() + values := []uint16{5004, 5000, 5002, 5001, 5003, 5005, 5006, 5007, 5008, 5009, 5010} + for _, v := range values { + tree.Insert(&rtp.Packet{Header: rtp.Header{SequenceNumber: v}}) } - popped, _ := q.Pop() - assert.Equal(popped.SequenceNumber, uint16(5000)) - _, _ = q.Pop() - nextPop, _ := q.Pop() - assert.Equal(nextPop.SequenceNumber, uint16(5012)) - }) - t.Run("Appends in order", func(*testing.T) { - q := NewQueue() - for i := 0; i < 100; i++ { - q.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: uint16(5012 + i), Timestamp: uint32(512 + i)}, Payload: []byte{0x02}}, uint16(5012+i)) + // Count black nodes in path to each leaf + var blackPathLength func(*RBNode, int) []int + blackPathLength = func(n *RBNode, blackCount int) []int { + if n == nil { + return []int{blackCount} + } + if n.color == black { + blackCount++ + } + leftPaths := blackPathLength(n.left, blackCount) + rightPaths := blackPathLength(n.right, blackCount) + return append(leftPaths, rightPaths...) } - assert.Equal(uint16(100), q.Length()) - pkt := &rtp.Packet{Header: rtp.Header{SequenceNumber: 5000, Timestamp: 500}, Payload: []byte{0x02}} - q.Push(pkt, pkt.SequenceNumber) - assert.Equal(pkt, q.next.val) - assert.Equal(uint16(101), q.Length()) - assert.Equal(q.next.priority, uint16(5000)) - }) - t.Run("Can find", func(*testing.T) { - q := NewQueue() - for i := 0; i < 100; i++ { - q.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: uint16(5012 + i), Timestamp: uint32(512 + i)}, Payload: []byte{0x02}}, uint16(5012+i)) + paths := blackPathLength(tree.root, 0) + firstPath := paths[0] + for i, pathLen := range paths { + if pathLen != firstPath { + t.Errorf("Unequal black path lengths: path %d has %d black nodes, expected %d", + i, pathLen, firstPath) + } } - pkt, err := q.Find(5012) - assert.Equal(pkt.SequenceNumber, uint16(5012)) - assert.Equal(err, nil) }) - t.Run("Updates the length when PopAt* are called", func(*testing.T) { - pkt := &rtp.Packet{Header: rtp.Header{SequenceNumber: 5000, Timestamp: 500}, Payload: []byte{0x02}} - q := NewQueue() - q.Push(pkt, pkt.SequenceNumber) - pkt2 := &rtp.Packet{Header: rtp.Header{SequenceNumber: 5004, Timestamp: 500}, Payload: []byte{0x02}} - q.Push(pkt2, pkt2.SequenceNumber) - for i := 0; i < 100; i++ { - q.Push(&rtp.Packet{Header: rtp.Header{SequenceNumber: uint16(5012 + i), Timestamp: uint32(512 + i)}, Payload: []byte{0x02}}, uint16(5012+i)) + t.Run("Black Node Children", func(t *testing.T) { + tree := NewTree() + values := []uint16{5004, 5000, 5002, 5001, 5003, 5005, 5006, 5007, 5008, 5009, 5010} + for _, v := range values { + tree.Insert(&rtp.Packet{Header: rtp.Header{SequenceNumber: v}}) + } + + var checkRedNodes func(*RBNode) bool + checkRedNodes = func(n *RBNode) bool { + if n == nil { + return true + } + if n.color == red { + // Check that red nodes have black children + if n.left != nil && n.left.color == red { + t.Errorf("Red node %d has red left child %d", n.priority, n.left.priority) + return false + } + if n.right != nil && n.right.color == red { + t.Errorf("Red node %d has red right child %d", n.priority, n.right.priority) + return false + } + } + return checkRedNodes(n.left) && checkRedNodes(n.right) } - assert.Equal(uint16(102), q.Length()) - popped, _ := q.PopAt(uint16(5012)) - assert.Equal(popped.SequenceNumber, uint16(5012)) - assert.Equal(uint16(101), q.Length()) - - popped, err := q.PopAtTimestamp(uint32(500)) - assert.Equal(popped.SequenceNumber, uint16(5000)) - assert.Equal(uint16(100), q.Length()) - assert.Equal(err, nil) + + checkRedNodes(tree.root) }) } -func TestPriorityQueue_Find(t *testing.T) { - packets := NewQueue() +func TestTreeStructure(t *testing.T) { + t.Run("Sequential Insertion Structure", func(t *testing.T) { + tree := NewTree() - packets.Push(&rtp.Packet{ - Header: rtp.Header{ - SequenceNumber: 1000, - Timestamp: 5, - SSRC: 5, - }, - Payload: []uint8{0xA}, - }, 1000) + // Insert first node (becomes root) + tree.Insert(&rtp.Packet{Header: rtp.Header{SequenceNumber: 5004}}) - _, err := packets.PopAt(1000) - assert.NoError(t, err) + // Verify initial state + assert.Equal(t, uint16(5004), tree.root.priority) + assert.Equal(t, black, tree.root.color) - _, err = packets.Find(1001) - assert.Error(t, err) + // Insert second node + tree.Insert(&rtp.Packet{Header: rtp.Header{SequenceNumber: 5002}}) + + // Verify after second insertion + assert.Equal(t, uint16(5004), tree.root.priority) + assert.Equal(t, black, tree.root.color) + assert.Equal(t, uint16(5002), tree.root.left.priority) + assert.Equal(t, red, tree.root.left.color) + + // Insert third node + tree.Insert(&rtp.Packet{Header: rtp.Header{SequenceNumber: 5000}}) + + // Verify final state after rebalancing + assert.Equal(t, uint16(5002), tree.root.priority) + assert.Equal(t, black, tree.root.color) + assert.Equal(t, uint16(5000), tree.root.left.priority) + assert.Equal(t, red, tree.root.left.color) + assert.Equal(t, uint16(5004), tree.root.right.priority) + assert.Equal(t, red, tree.root.right.color) + }) } -func TestPriorityQueue_Clean(t *testing.T) { - packets := NewQueue() - packets.Clear() - packets.Push(&rtp.Packet{ - Header: rtp.Header{ - SequenceNumber: 1000, - Timestamp: 5, - SSRC: 5, - }, - Payload: []uint8{0xA}, - }, 1000) - assert.EqualValues(t, 1, packets.Length()) - packets.Clear() +func TestBalancedStructure(t *testing.T) { + t.Run("Sequential Right-Side Insertions", func(t *testing.T) { + tree := NewTree() + values := []uint16{5004, 5000, 5002, 5001, 5003, 5005, 5006, 5007, 5008, 5009, 5010} + + for _, v := range values { + tree.Insert(&rtp.Packet{Header: rtp.Header{SequenceNumber: v}}) + } + + // Check Red-Black tree properties + var verifyProperties func(*RBNode) (int, bool) + verifyProperties = func(n *RBNode) (blackHeight int, valid bool) { + if n == nil { + return 0, true + } + + // Property 1: Node is either red or black (implicit in our implementation) + + // Property 2: Red nodes cannot have red children + if n.color == red { + if (n.left != nil && n.left.color == red) || + (n.right != nil && n.right.color == red) { + t.Errorf("Red node %d has a red child", n.priority) + return 0, false + } + } + + leftHeight, leftValid := verifyProperties(n.left) + rightHeight, rightValid := verifyProperties(n.right) + + // Property 3: All paths must have same number of black nodes + if leftHeight != rightHeight { + t.Errorf("Black height mismatch at node %d: left=%d, right=%d", + n.priority, leftHeight, rightHeight) + return 0, false + } + + // Calculate black height of current path + currentBlackHeight := leftHeight + if n.color == black { + currentBlackHeight++ + } + + return currentBlackHeight, leftValid && rightValid + } + + // Property 4: Root must be black + if tree.root.color != black { + t.Error("Root node is not black") + } + + verifyProperties(tree.root) + }) } -func TestPriorityQueue_Unreference(t *testing.T) { - packets := NewQueue() +func TestPeekAndDelete(t *testing.T) { + t.Run("Peek Operations", func(t *testing.T) { + tree := NewTree() + // Test empty tree + _, err := tree.Peek(5000) + assert.Error(t, err, "Peek on empty tree should return error") + + values := []uint16{5004, 5000, 5002, 5001, 5003} + for _, v := range values { + tree.Insert(&rtp.Packet{Header: rtp.Header{SequenceNumber: v}}) + } + + // Test finding existing values + for _, v := range values { + packet, err := tree.Peek(v) + assert.NoError(t, err) + assert.Equal(t, v, packet.SequenceNumber) + } + + // Test finding non-existent value + _, err = tree.Peek(5999) + assert.Error(t, err, "Peek for non-existent value should return error") + }) + + t.Run("Delete Operations", func(t *testing.T) { + t.Run("Delete Leaf Node", func(t *testing.T) { + tree := NewTree() + values := []uint16{5004, 5000, 5002} + for _, v := range values { + tree.Insert(&rtp.Packet{Header: rtp.Header{SequenceNumber: v}}) + } + err := tree.Delete(5000) + assert.NoError(t, err) + + _, err = tree.Peek(5000) + assert.Error(t, err) + + assert.NoError(t, validateRBProperties(t, tree)) + }) + + t.Run("Delete Node with One Child", func(t *testing.T) { + tree := NewTree() + values := []uint16{5004, 5002, 5000} + for _, v := range values { + tree.Insert(&rtp.Packet{Header: rtp.Header{SequenceNumber: v}}) + } + + err := tree.Delete(5002) + assert.NoError(t, err) + + assert.Equal(t, uint16(5004), tree.root.priority) + assert.Equal(t, uint16(5000), tree.root.left.priority) + + _, err = tree.Peek(5002) + assert.Error(t, err) + }) + + t.Run("Delete Node with Two Children", func(t *testing.T) { + tree := NewTree() + values := []uint16{5004, 5000, 5002, 5001, 5003} + for _, v := range values { + tree.Insert(&rtp.Packet{Header: rtp.Header{SequenceNumber: v}}) + } + + err := tree.Delete(5002) + assert.NoError(t, err) + + assert.NoError(t, validateRBProperties(t, tree)) + + _, err = tree.Peek(5002) + assert.Error(t, err) + }) + + t.Run("Delete Root", func(t *testing.T) { + tree := NewTree() + values := []uint16{5004, 5000, 5002} + for _, v := range values { + tree.Insert(&rtp.Packet{Header: rtp.Header{SequenceNumber: v}}) + } + + err := tree.Delete(5004) + assert.NoError(t, err) + + assert.Equal(t, uint16(5002), tree.root.priority) + assert.Equal(t, black, tree.root.color) + + _, err = tree.Peek(5004) + assert.Error(t, err) + }) + + t.Run("Delete Non-existent Node", func(t *testing.T) { + tree := NewTree() + tree.Insert(&rtp.Packet{Header: rtp.Header{SequenceNumber: 5004}}) + + err := tree.Delete(5999) + assert.Error(t, err, "Deleting non-existent node should return error") + }) + + t.Run("Delete from Empty Tree", func(t *testing.T) { + tree := NewTree() + err := tree.Delete(5000) + assert.Error(t, err, "Deleting from empty tree should return error") + }) + }) +} + +func TestMemoryLeaks(t *testing.T) { var refs int64 finalizer := func(*rtp.Packet) { atomic.AddInt64(&refs, -1) } - numPkts := 100 - for i := 0; i < numPkts; i++ { - atomic.AddInt64(&refs, 1) - seq := uint16(i) - p := rtp.Packet{ - Header: rtp.Header{ - SequenceNumber: seq, - Timestamp: uint32(i + 42), - }, - Payload: []byte{byte(i)}, + t.Run("Insert and Delete Memory Management", func(t *testing.T) { + tree := NewTree() + const numOperations = 1000 + + for i := uint16(0); i < numOperations; i++ { + pkt := &rtp.Packet{Header: rtp.Header{SequenceNumber: i}} + runtime.SetFinalizer(pkt, finalizer) + atomic.AddInt64(&refs, 1) + tree.Insert(pkt) } - runtime.SetFinalizer(&p, finalizer) - packets.Push(&p, seq) - } - for i := 0; i < numPkts-1; i++ { - switch i % 3 { - case 0: - packets.Pop() //nolint - case 1: - packets.PopAt(uint16(i)) //nolint - case 2: - packets.PopAtTimestamp(uint32(i + 42)) //nolint + + for i := uint16(0); i < numOperations; i++ { + err := tree.Delete(i) + assert.NoError(t, err) } + + runtime.GC() + + time.Sleep(time.Millisecond * 100) + + // Verify all packets were freed + assert.Equal(t, int64(0), atomic.LoadInt64(&refs), + "Memory leak detected: %d packets not freed", atomic.LoadInt64(&refs)) + + assert.Nil(t, tree.root) + }) +} + +func validateRBProperties(t *testing.T, tree *RBTree) error { + if tree.root == nil { + return nil + } + + // Property 1: Root must be black + if tree.root.color != black { + return fmt.Errorf("root is not black") + } + + // Property 2: Red nodes must have black children + // Property 3: All paths must have same number of black nodes + blackHeight, err := validateNode(tree.root, black) + if err != nil { + return err + } + if blackHeight < 0 { + return fmt.Errorf("invalid black height") + } + + return nil +} + +func validateNode(node *RBNode, parentColor treeColor) (int, error) { + if node == nil { + return 0, nil } - runtime.GC() - time.Sleep(10 * time.Millisecond) + // Check red property violation + if node.color == red && parentColor == red { + return -1, fmt.Errorf("red node %d has red parent", node.priority) + } - remainedRefs := atomic.LoadInt64(&refs) - runtime.KeepAlive(packets) + // Check left subtree + leftHeight, err := validateNode(node.left, node.color) + if err != nil { + return -1, err + } - // only the last packet should be still referenced - assert.Equal(t, int64(1), remainedRefs) + // Check right subtree + rightHeight, err := validateNode(node.right, node.color) + if err != nil { + return -1, err + } + + // Verify black height property + if leftHeight != rightHeight { + return -1, fmt.Errorf("unequal black heights at node %d", node.priority) + } + + // Add 1 to black height if current node is black + blackHeight := leftHeight + if node.color == black { + blackHeight++ + } + + return blackHeight, nil +} +func abs(x int) int { + if x < 0 { + return -x + } + return x } diff --git a/pkg/jitterbuffer/receiver_interceptor.go b/pkg/jitterbuffer/receiver_interceptor.go index b4c032b9..4f66bb13 100644 --- a/pkg/jitterbuffer/receiver_interceptor.go +++ b/pkg/jitterbuffer/receiver_interceptor.go @@ -4,6 +4,7 @@ package jitterbuffer import ( + "errors" "sync" "github.com/pion/interceptor" @@ -17,10 +18,13 @@ type InterceptorFactory struct { } // NewInterceptor constructs a new ReceiverInterceptor -func (g *InterceptorFactory) NewInterceptor(_ string) (interceptor.Interceptor, error) { +func (g *InterceptorFactory) NewInterceptor(logName string) (interceptor.Interceptor, error) { + if logName == "" { + logName = "jitterbuffer" + } i := &ReceiverInterceptor{ close: make(chan struct{}), - log: logging.NewDefaultLoggerFactory().NewLogger("jitterbuffer"), + log: logging.NewDefaultLoggerFactory().NewLogger(logName), buffer: New(), } @@ -52,11 +56,11 @@ func (g *InterceptorFactory) NewInterceptor(_ string) (interceptor.Interceptor, // arriving) quickly enough. type ReceiverInterceptor struct { interceptor.NoOp - buffer *JitterBuffer - m sync.Mutex - wg sync.WaitGroup - close chan struct{} - log logging.LeveledLogger + buffer *JitterBuffer + wg sync.WaitGroup + close chan struct{} + log logging.LeveledLogger + skipMissingPackets bool } // NewInterceptor returns a new InterceptorFactory @@ -74,19 +78,32 @@ func (i *ReceiverInterceptor) BindRemoteStream(_ *interceptor.StreamInfo, reader return n, attr, err } packet := &rtp.Packet{} - if err := packet.Unmarshal(buf); err != nil { + if err := packet.Unmarshal(buf[:n]); err != nil { return 0, nil, err } - i.m.Lock() - defer i.m.Unlock() i.buffer.Push(packet) if i.buffer.state == Emitting { - newPkt, err := i.buffer.Pop() - if err != nil { - return 0, nil, err + for { + newPkt, err := i.buffer.Pop() + if err != nil { + if errors.Is(err, ErrNotFound) { + if i.skipMissingPackets { + i.log.Warn("Skipping missing packet") + i.buffer.SetPlayoutHead(i.buffer.PlayoutHead() + 1) + continue + } + } + return 0, nil, err + } + if newPkt != nil { + nlen, err := newPkt.MarshalTo(b) + return nlen, attr, err + + } + if i.buffer.Length() == 0 { + break + } } - nlen, err := newPkt.MarshalTo(b) - return nlen, attr, err } return n, attr, ErrPopWhileBuffering }) @@ -95,16 +112,12 @@ func (i *ReceiverInterceptor) BindRemoteStream(_ *interceptor.StreamInfo, reader // UnbindRemoteStream is called when the Stream is removed. It can be used to clean up any data related to that track. func (i *ReceiverInterceptor) UnbindRemoteStream(_ *interceptor.StreamInfo) { defer i.wg.Wait() - i.m.Lock() - defer i.m.Unlock() i.buffer.Clear(true) } // Close closes the interceptor func (i *ReceiverInterceptor) Close() error { defer i.wg.Wait() - i.m.Lock() - defer i.m.Unlock() i.buffer.Clear(true) return nil } diff --git a/pkg/jitterbuffer/receiver_interceptor_test.go b/pkg/jitterbuffer/receiver_interceptor_test.go index 58685966..f9f31cea 100644 --- a/pkg/jitterbuffer/receiver_interceptor_test.go +++ b/pkg/jitterbuffer/receiver_interceptor_test.go @@ -80,15 +80,18 @@ func TestReceiverBuffersAndPlaysout(t *testing.T) { SenderSSRC: 123, MediaSSRC: 456, }}) - for s := 0; s < 61; s++ { + for s := 0; s < 910; s++ { stream.ReceiveRTP(&rtp.Packet{Header: rtp.Header{ SequenceNumber: uint16(s), }}) } // Give time for packets to be handled and stream written to. time.Sleep(50 * time.Millisecond) - for s := 0; s < 10; s++ { + for s := 0; s < 50; s++ { read := <-stream.ReadRTP() + if read.Err != nil { + t.Fatal(read.Err) + } seq := read.Packet.Header.SequenceNumber assert.EqualValues(t, uint16(s), seq) } @@ -96,3 +99,53 @@ func TestReceiverBuffersAndPlaysout(t *testing.T) { err = i.Close() assert.NoError(t, err) } + +func TestReceiverBuffersAndPlaysoutSkippingMissingPackets(t *testing.T) { + buf := bytes.Buffer{} + + factory, err := NewInterceptor( + Log(logging.NewDefaultLoggerFactory().NewLogger("test")), + WithSkipMissingPackets(), + ) + assert.NoError(t, err) + + i, err := factory.NewInterceptor("jitterbuffer") + assert.NoError(t, err) + + assert.EqualValues(t, 0, buf.Len()) + + stream := test.NewMockStream(&interceptor.StreamInfo{ + SSRC: 123456, + ClockRate: 90000, + }, i) + + for s := 0; s < 420; s++ { + if s == 6 { + s++ + } + if s == 40 { + s = s + 20 + } + stream.ReceiveRTP(&rtp.Packet{Header: rtp.Header{ + SequenceNumber: uint16(s), + }}) + } + + for s := 0; s < 100; s++ { + read := <-stream.ReadRTP() + if read.Err != nil { + continue + } + seq := read.Packet.Header.SequenceNumber + if s == 6 { + s++ + } + if s == 40 { + s = s + 20 + } + assert.EqualValues(t, uint16(s), seq) + } + assert.NoError(t, stream.Close()) + err = i.Close() + assert.NoError(t, err) +}