From 33e34637ac1cbb84e1c24ae6a935edd481cfda4b Mon Sep 17 00:00:00 2001 From: Mathis Engelbart Date: Sat, 26 Jun 2021 14:05:31 +0200 Subject: [PATCH] Implement first version of loss based GCC First draft of the loss based congestion controller of Google Congestion Control. --- AUTHORS.txt | 1 + internal/types/datarate.go | 33 ++ internal/types/packet_result.go | 20 ++ pkg/cc/feedback_adapter.go | 166 +++++++++ pkg/cc/feedback_adapter_test.go | 579 ++++++++++++++++++++++++++++++++ pkg/cc/interceptor.go | 260 ++++++++++++++ pkg/cc/leaky_bucket_pacer.go | 124 +++++++ pkg/cc/noop_pacer.go | 50 +++ pkg/gcc/delay_based_bwe.go | 13 + pkg/gcc/gcc.go | 3 + pkg/gcc/loss_based_bwe.go | 66 ++++ pkg/gcc/queue.go | 56 +++ pkg/gcc/send_side_bwe.go | 38 +++ 13 files changed, 1409 insertions(+) create mode 100644 internal/types/datarate.go create mode 100644 internal/types/packet_result.go create mode 100644 pkg/cc/feedback_adapter.go create mode 100644 pkg/cc/feedback_adapter_test.go create mode 100644 pkg/cc/interceptor.go create mode 100644 pkg/cc/leaky_bucket_pacer.go create mode 100644 pkg/cc/noop_pacer.go create mode 100644 pkg/gcc/delay_based_bwe.go create mode 100644 pkg/gcc/gcc.go create mode 100644 pkg/gcc/loss_based_bwe.go create mode 100644 pkg/gcc/queue.go create mode 100644 pkg/gcc/send_side_bwe.go diff --git a/AUTHORS.txt b/AUTHORS.txt index 360b8fa6..ea210523 100644 --- a/AUTHORS.txt +++ b/AUTHORS.txt @@ -8,6 +8,7 @@ adamroach aler9 <46489434+aler9@users.noreply.github.com> Antoine Baché Atsushi Watanabe +boks1971 Jonathan Müller Mathis Engelbart Sean DuBois diff --git a/internal/types/datarate.go b/internal/types/datarate.go new file mode 100644 index 00000000..fb5fdd2d --- /dev/null +++ b/internal/types/datarate.go @@ -0,0 +1,33 @@ +package types + +const ( + // BitPerSecond is a data rate of 1 bit per second + BitPerSecond = DataRate(1) + // KiloBitPerSecond is a data rate of 1 kilobit per second + KiloBitPerSecond = 1000 * BitPerSecond + // MegaBitPerSecond is a data rate of 1 megabit per second + MegaBitPerSecond = 1000 * KiloBitPerSecond +) + +// DataRate in bit per second +type DataRate int + +// BitsPerMillisecond returns the datarate in b/ms (bits per millisecond). +func (r DataRate) BitsPerMillisecond() int { + return int(r / 1000.0) +} + +// MaxDataRate returns the maximum of the given DataRates. +func MaxDataRate(a, b DataRate) DataRate { + if a > b { + return a + } + return b +} + +func MinDataRate(a, b DataRate) DataRate { + if a < b { + return a + } + return b +} diff --git a/internal/types/packet_result.go b/internal/types/packet_result.go new file mode 100644 index 00000000..6b1ff71c --- /dev/null +++ b/internal/types/packet_result.go @@ -0,0 +1,20 @@ +package types + +import ( + "time" + + "github.com/pion/rtp" +) + +type SentPacket struct { + SendTime time.Time + Header *rtp.Header +} + +// PacketResult holds information about a packet and if/when it has been +// sent/received. +type PacketResult struct { + SentPacket SentPacket + ReceiveTime time.Time + Received bool +} diff --git a/pkg/cc/feedback_adapter.go b/pkg/cc/feedback_adapter.go new file mode 100644 index 00000000..a8afe866 --- /dev/null +++ b/pkg/cc/feedback_adapter.go @@ -0,0 +1,166 @@ +package cc + +import ( + "errors" + "sort" + "sync" + "time" + + "github.com/pion/interceptor" + "github.com/pion/interceptor/internal/types" + "github.com/pion/rtcp" + "github.com/pion/rtp" +) + +var errMissingTWCCExtension = errors.New("missing transport layer cc header extension") +var errInvalidFeedbackPacket = errors.New("got invalid feedback packet") + +// TODO(mathis): make types internal only? + +// FeedbackAdapter converts incoming feedback from the wireformat to a +// PacketResult +type FeedbackAdapter struct { + lock sync.Mutex + history map[uint16]types.SentPacket +} + +// NewFeedbackAdapter returns a new FeedbackAdapter +func NewFeedbackAdapter() *FeedbackAdapter { + return &FeedbackAdapter{ + history: make(map[uint16]types.SentPacket), + } +} + +// OnSent records when a packet was been sent. +// TODO(mathis): Is there a better way to get attributes in here? +func (f *FeedbackAdapter) OnSent(ts time.Time, header *rtp.Header, attributes interceptor.Attributes) error { + hdrExtensionID := attributes.Get(twccExtension) + id, ok := hdrExtensionID.(uint8) + if !ok || hdrExtensionID == 0 { + return errMissingTWCCExtension + } + sequenceNumber := header.GetExtension(id) + var tccExt rtp.TransportCCExtension + err := tccExt.Unmarshal(sequenceNumber) + if err != nil { + return err + } + + f.lock.Lock() + defer f.lock.Unlock() + f.history[tccExt.TransportSequence] = types.SentPacket{ + SendTime: ts, + Header: header, + } + return nil +} + +// OnIncomingTransportCC converts the incoming rtcp.TransportLayerCC to a +// []PacketResult +func (f *FeedbackAdapter) OnIncomingTransportCC(feedback *rtcp.TransportLayerCC) ([]types.PacketResult, error) { + f.lock.Lock() + defer f.lock.Unlock() + + result := []types.PacketResult{} + + packetStatusCount := uint16(0) + chunkIndex := 0 + deltaIndex := 0 + referenceTime := time.Time{}.Add(time.Duration(feedback.ReferenceTime) * 64 * time.Millisecond) + + for packetStatusCount < feedback.PacketStatusCount { + if chunkIndex >= len(feedback.PacketChunks) || len(feedback.PacketChunks) == 0 { + return nil, errInvalidFeedbackPacket + } + switch packetChunk := feedback.PacketChunks[chunkIndex].(type) { + case *rtcp.RunLengthChunk: + symbol := packetChunk.PacketStatusSymbol + for i := uint16(0); i < packetChunk.RunLength; i++ { + if sentPacket, ok := f.history[feedback.BaseSequenceNumber+packetStatusCount]; ok { + if symbol == rtcp.TypeTCCPacketReceivedSmallDelta || + symbol == rtcp.TypeTCCPacketReceivedLargeDelta { + if deltaIndex >= len(feedback.RecvDeltas) { + // TODO(mathis): Not enough recv deltas for number + // of received packets: warn or error? + continue + } + receiveTime := getReceiveTime(referenceTime, feedback.RecvDeltas[deltaIndex]) + referenceTime = receiveTime + result = append(result, types.PacketResult{ + SentPacket: sentPacket, + ReceiveTime: receiveTime, + Received: true, + }) + deltaIndex++ + } else { + result = append(result, types.PacketResult{ + SentPacket: sentPacket, + ReceiveTime: time.Time{}, + Received: false, + }) + } + } else { + // TODO(mathis): got feedback for unsent packet? + } + packetStatusCount++ + } + chunkIndex++ + case *rtcp.StatusVectorChunk: + for _, symbol := range packetChunk.SymbolList { + if sentPacket, ok := f.history[feedback.BaseSequenceNumber+packetStatusCount]; ok { + if symbol == rtcp.TypeTCCPacketReceivedSmallDelta || + symbol == rtcp.TypeTCCPacketReceivedLargeDelta { + if deltaIndex >= len(feedback.RecvDeltas) { + // TODO(mathis): Not enough recv deltas for number + // of received packets: warn or error? + continue + } + receiveTime := getReceiveTime(referenceTime, feedback.RecvDeltas[deltaIndex]) + referenceTime = receiveTime + result = append(result, types.PacketResult{ + SentPacket: sentPacket, + ReceiveTime: receiveTime, + Received: true, + }) + deltaIndex++ + } else { + result = append(result, types.PacketResult{ + SentPacket: sentPacket, + ReceiveTime: time.Time{}, + Received: false, + }) + } + } + packetStatusCount++ + if packetStatusCount >= feedback.PacketStatusCount { + break + } + } + chunkIndex++ + } + } + return result, nil +} + +// OnIncomingRFC8888 converts the incoming RFC8888 packet to a []PacketResult +func (f *FeedbackAdapter) OnIncomingRFC8888(feedback *rtcp.RawPacket) ([]types.PacketResult, error) { + return nil, nil +} + +func sortedKeysUint16(m map[uint16]types.SentPacket) []uint16 { + var result []uint16 + for k := range m { + result = append(result, k) + } + sort.Slice(result, func(i, j int) bool { + return result[i] < result[j] + }) + return result +} + +func getReceiveTime(baseTime time.Time, delta *rtcp.RecvDelta) time.Time { + if delta.Type == rtcp.TypeTCCPacketReceivedSmallDelta { + return baseTime.Add(time.Duration(delta.Delta) * 250 * time.Microsecond) + } + return baseTime.Add(time.Duration(delta.Delta) * time.Millisecond) +} diff --git a/pkg/cc/feedback_adapter_test.go b/pkg/cc/feedback_adapter_test.go new file mode 100644 index 00000000..eedfe0ad --- /dev/null +++ b/pkg/cc/feedback_adapter_test.go @@ -0,0 +1,579 @@ +package cc + +import ( + "testing" + "time" + + "github.com/pion/interceptor" + "github.com/pion/interceptor/internal/types" + "github.com/pion/rtcp" + "github.com/pion/rtp" + "github.com/stretchr/testify/assert" +) + +const hdrExtID = uint8(1) + +func getPacketWithTransportCCExt(t *testing.T, SequenceNumber uint16) *rtp.Packet { + pkt := rtp.Packet{ + Header: rtp.Header{}, + Payload: []byte{}, + } + ext := &rtp.TransportCCExtension{ + TransportSequence: SequenceNumber, + } + b, err := ext.Marshal() + assert.NoError(t, err) + assert.NoError(t, pkt.SetExtension(hdrExtID, b)) + return &pkt +} + +func TestFeedbackAdapterTWCC(t *testing.T) { + t.Run("empty", func(t *testing.T) { + adapter := NewFeedbackAdapter() + result, err := adapter.OnIncomingTransportCC(&rtcp.TransportLayerCC{}) + assert.NoError(t, err) + assert.Empty(t, result) + }) + + t.Run("setsCorrectReceiveTime", func(t *testing.T) { + t0 := time.Time{} + adapter := NewFeedbackAdapter() + headers := []rtp.Header{} + for i := uint16(0); i < 22; i++ { + pkt := getPacketWithTransportCCExt(t, i) + headers = append(headers, pkt.Header) + assert.NoError(t, adapter.OnSent(t0, &pkt.Header, interceptor.Attributes{twccExtension: hdrExtID})) + } + results, err := adapter.OnIncomingTransportCC(&rtcp.TransportLayerCC{ + Header: rtcp.Header{}, + SenderSSRC: 0, + MediaSSRC: 0, + BaseSequenceNumber: 0, + PacketStatusCount: 22, + ReferenceTime: 0, + FbPktCount: 0, + PacketChunks: []rtcp.PacketStatusChunk{ + &rtcp.StatusVectorChunk{ + PacketStatusChunk: nil, + Type: rtcp.TypeTCCStatusVectorChunk, + SymbolSize: rtcp.TypeTCCSymbolSizeTwoBit, + SymbolList: []uint16{ + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketReceivedLargeDelta, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + }, + }, + &rtcp.StatusVectorChunk{ + PacketStatusChunk: nil, + Type: rtcp.TypeTCCStatusVectorChunk, + SymbolSize: rtcp.TypeTCCSymbolSizeOneBit, + SymbolList: []uint16{ + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + }, + }, + &rtcp.RunLengthChunk{ + Type: rtcp.TypeTCCRunLengthChunk, + PacketStatusSymbol: rtcp.TypeTCCPacketReceivedSmallDelta, + RunLength: 1, + }, + }, + RecvDeltas: []*rtcp.RecvDelta{ + { + Type: rtcp.TypeTCCPacketReceivedSmallDelta, + Delta: 4, // 4*250us=1ms + }, + { + Type: rtcp.TypeTCCPacketReceivedLargeDelta, + Delta: 100, + }, + { + Type: rtcp.TypeTCCPacketReceivedSmallDelta, + Delta: 12, // 3*4*250us=3ms + }, + { + Type: rtcp.TypeTCCPacketReceivedSmallDelta, + Delta: 4, + }, + }, + }) + + assert.NoError(t, err) + + assert.NotEmpty(t, results) + assert.Len(t, results, 22) + + assert.Contains(t, results, types.PacketResult{ + SentPacket: types.SentPacket{ + SendTime: t0, + Header: &headers[0], + }, + ReceiveTime: t0.Add(time.Millisecond), + Received: true, + }) + + assert.Contains(t, results, types.PacketResult{ + SentPacket: types.SentPacket{ + SendTime: t0, + Header: &headers[1], + }, + ReceiveTime: t0.Add(101 * time.Millisecond), + Received: true, + }) + + for i := uint16(2); i < 7; i++ { + assert.Contains(t, results, types.PacketResult{ + SentPacket: types.SentPacket{ + SendTime: t0, + Header: &headers[i], + }, + ReceiveTime: time.Time{}, + Received: false, + }) + } + + assert.Contains(t, results, types.PacketResult{ + SentPacket: types.SentPacket{ + SendTime: t0, + Header: &headers[7], + }, + ReceiveTime: t0.Add(104 * time.Millisecond), + Received: true, + }) + + for i := uint16(8); i < 21; i++ { + assert.Contains(t, results, types.PacketResult{ + SentPacket: types.SentPacket{ + SendTime: t0, + Header: &headers[i], + }, + ReceiveTime: time.Time{}, + Received: false, + }) + } + + assert.Contains(t, results, types.PacketResult{ + SentPacket: types.SentPacket{ + SendTime: t0, + Header: &headers[21], + }, + ReceiveTime: t0.Add(105 * time.Millisecond), + Received: true, + }) + }) + + t.Run("doesNotCrashOnTooManyFeedbackReports", func(*testing.T) { + adapter := NewFeedbackAdapter() + assert.NotPanics(t, func() { + _, err := adapter.OnIncomingTransportCC(&rtcp.TransportLayerCC{ + Header: rtcp.Header{}, + SenderSSRC: 0, + MediaSSRC: 0, + BaseSequenceNumber: 0, + PacketStatusCount: 0, + ReferenceTime: 0, + FbPktCount: 0, + PacketChunks: []rtcp.PacketStatusChunk{ + &rtcp.StatusVectorChunk{ + PacketStatusChunk: nil, + Type: rtcp.TypeTCCStatusVectorChunk, + SymbolSize: rtcp.TypeTCCSymbolSizeTwoBit, + SymbolList: []uint16{ + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + }, + }, + }, + RecvDeltas: []*rtcp.RecvDelta{ + { + Type: rtcp.TypeTCCPacketReceivedSmallDelta, + Delta: 4, // 4*250us=1ms + }, + }, + }) + assert.NoError(t, err) + }) + }) + + t.Run("worksOnSequenceNumberWrapAround", func(t *testing.T) { + t0 := time.Time{} + adapter := NewFeedbackAdapter() + pkt65535 := getPacketWithTransportCCExt(t, 65535) + pkt0 := getPacketWithTransportCCExt(t, 0) + assert.NoError(t, adapter.OnSent(t0, &pkt65535.Header, interceptor.Attributes{twccExtension: hdrExtID})) + assert.NoError(t, adapter.OnSent(t0, &pkt0.Header, interceptor.Attributes{twccExtension: hdrExtID})) + + results, err := adapter.OnIncomingTransportCC(&rtcp.TransportLayerCC{ + Header: rtcp.Header{}, + SenderSSRC: 0, + MediaSSRC: 0, + BaseSequenceNumber: 65535, + PacketStatusCount: 2, + ReferenceTime: 0, + FbPktCount: 0, + PacketChunks: []rtcp.PacketStatusChunk{ + &rtcp.StatusVectorChunk{ + PacketStatusChunk: nil, + Type: rtcp.TypeTCCStatusVectorChunk, + SymbolSize: rtcp.TypeTCCSymbolSizeTwoBit, + SymbolList: []uint16{ + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + }, + }, + }, + RecvDeltas: []*rtcp.RecvDelta{ + { + Type: rtcp.TypeTCCPacketReceivedSmallDelta, + Delta: 4, + }, + { + Type: rtcp.TypeTCCPacketReceivedSmallDelta, + Delta: 4, + }, + }, + }) + assert.NoError(t, err) + + assert.NotEmpty(t, results) + assert.Len(t, results, 2) + assert.Contains(t, results, types.PacketResult{ + SentPacket: types.SentPacket{ + SendTime: t0, + Header: &pkt65535.Header, + }, + ReceiveTime: t0.Add(1 * time.Millisecond), + Received: true, + }) + assert.Contains(t, results, types.PacketResult{ + SentPacket: types.SentPacket{ + SendTime: t0, + Header: &pkt0.Header, + }, + ReceiveTime: t0.Add(2 * time.Millisecond), + Received: true, + }) + }) + + t.Run("ignoresPossiblyInFlightPackets", func(t *testing.T) { + t0 := time.Time{} + adapter := NewFeedbackAdapter() + headers := []rtp.Header{} + for i := uint16(0); i < 8; i++ { + pkt := getPacketWithTransportCCExt(t, i) + headers = append(headers, pkt.Header) + assert.NoError(t, adapter.OnSent(t0, &pkt.Header, interceptor.Attributes{twccExtension: hdrExtID})) + } + + results, err := adapter.OnIncomingTransportCC(&rtcp.TransportLayerCC{ + Header: rtcp.Header{}, + SenderSSRC: 0, + MediaSSRC: 0, + BaseSequenceNumber: 0, + PacketStatusCount: 3, + ReferenceTime: 0, + FbPktCount: 0, + PacketChunks: []rtcp.PacketStatusChunk{ + &rtcp.StatusVectorChunk{ + PacketStatusChunk: nil, + Type: rtcp.TypeTCCStatusVectorChunk, + SymbolSize: rtcp.TypeTCCSymbolSizeTwoBit, + SymbolList: []uint16{ + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + }, + }, + }, + RecvDeltas: []*rtcp.RecvDelta{ + { + Type: rtcp.TypeTCCPacketReceivedSmallDelta, + Delta: 4, // 4*250us=1ms + }, + { + Type: rtcp.TypeTCCPacketReceivedSmallDelta, + Delta: 4, // 4*250us=1ms + }, + { + Type: rtcp.TypeTCCPacketReceivedSmallDelta, + Delta: 4, // 4*250us=1ms + }, + }, + }) + assert.NoError(t, err) + assert.Len(t, results, 3) + for i := uint16(0); i < 3; i++ { + assert.Contains(t, results, types.PacketResult{ + SentPacket: types.SentPacket{ + SendTime: t0, + Header: &headers[i], + }, + ReceiveTime: t0.Add(time.Duration(i+1) * time.Millisecond), + Received: true, + }) + } + }) + + t.Run("runLengthChunk", func(t *testing.T) { + adapter := NewFeedbackAdapter() + t0 := time.Time{} + for i := uint16(0); i < 20; i++ { + pkt := getPacketWithTransportCCExt(t, i) + assert.NoError(t, adapter.OnSent(t0, &pkt.Header, interceptor.Attributes{twccExtension: hdrExtID})) + } + packets, err := adapter.OnIncomingTransportCC(&rtcp.TransportLayerCC{ + Header: rtcp.Header{}, + SenderSSRC: 0, + MediaSSRC: 0, + BaseSequenceNumber: 0, + PacketStatusCount: 3, + ReferenceTime: 0, + FbPktCount: 0, + PacketChunks: []rtcp.PacketStatusChunk{ + &rtcp.RunLengthChunk{ + PacketStatusSymbol: rtcp.TypeTCCPacketReceivedSmallDelta, + RunLength: 3, + }, + }, + RecvDeltas: []*rtcp.RecvDelta{ + { + Type: rtcp.TypeTCCPacketReceivedSmallDelta, + Delta: 4, + }, + { + Type: rtcp.TypeTCCPacketReceivedSmallDelta, + Delta: 4, + }, + { + Type: rtcp.TypeTCCPacketReceivedSmallDelta, + Delta: 4, + }, + }, + }) + + assert.NoError(t, err) + assert.Len(t, packets, 3) + }) + + t.Run("statusVectorChunk", func(t *testing.T) { + adapter := NewFeedbackAdapter() + t0 := time.Time{} + for i := uint16(0); i < 20; i++ { + pkt := getPacketWithTransportCCExt(t, i) + assert.NoError(t, adapter.OnSent(t0, &pkt.Header, interceptor.Attributes{twccExtension: hdrExtID})) + } + packets, err := adapter.OnIncomingTransportCC(&rtcp.TransportLayerCC{ + Header: rtcp.Header{}, + SenderSSRC: 0, + MediaSSRC: 0, + BaseSequenceNumber: 0, + PacketStatusCount: 3, + ReferenceTime: 0, + FbPktCount: 0, + PacketChunks: []rtcp.PacketStatusChunk{ + &rtcp.StatusVectorChunk{ + SymbolSize: rtcp.TypeTCCSymbolSizeOneBit, + SymbolList: []uint16{ + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + }, + }, + }, + RecvDeltas: []*rtcp.RecvDelta{ + { + Type: rtcp.TypeTCCPacketReceivedSmallDelta, + Delta: 4, + }, + { + Type: rtcp.TypeTCCPacketReceivedSmallDelta, + Delta: 4, + }, + { + Type: rtcp.TypeTCCPacketReceivedSmallDelta, + Delta: 4, + }, + }, + }) + + assert.NoError(t, err) + assert.Len(t, packets, 3) + }) + + t.Run("mixedRunLengthAndStatusVector", func(t *testing.T) { + adapter := NewFeedbackAdapter() + + t0 := time.Time{} + for i := uint16(0); i < 20; i++ { + pkt := getPacketWithTransportCCExt(t, i) + assert.NoError(t, adapter.OnSent(t0, &pkt.Header, interceptor.Attributes{twccExtension: hdrExtID})) + } + + packets, err := adapter.OnIncomingTransportCC(&rtcp.TransportLayerCC{ + Header: rtcp.Header{}, + SenderSSRC: 0, + MediaSSRC: 0, + BaseSequenceNumber: 0, + PacketStatusCount: 10, + ReferenceTime: 0, + FbPktCount: 0, + PacketChunks: []rtcp.PacketStatusChunk{ + &rtcp.StatusVectorChunk{ + SymbolSize: rtcp.TypeTCCSymbolSizeTwoBit, + SymbolList: []uint16{ + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + rtcp.TypeTCCPacketNotReceived, + }, + }, + &rtcp.RunLengthChunk{ + PacketStatusSymbol: rtcp.TypeTCCPacketReceivedSmallDelta, + RunLength: 3, + }, + }, + RecvDeltas: []*rtcp.RecvDelta{ + { + Type: rtcp.TypeTCCPacketReceivedSmallDelta, + Delta: 4, + }, + { + Type: rtcp.TypeTCCPacketReceivedSmallDelta, + Delta: 4, + }, + { + Type: rtcp.TypeTCCPacketReceivedSmallDelta, + Delta: 4, + }, + { + Type: rtcp.TypeTCCPacketReceivedSmallDelta, + Delta: 4, + }, + { + Type: rtcp.TypeTCCPacketReceivedSmallDelta, + Delta: 4, + }, + { + Type: rtcp.TypeTCCPacketReceivedSmallDelta, + Delta: 4, + }, + }, + }) + assert.NoError(t, err) + assert.Len(t, packets, 10) + }) + + t.Run("doesNotcrashOnInvalidTWCCPacket", func(t *testing.T) { + + adapter := NewFeedbackAdapter() + + t0 := time.Time{} + for i := uint16(1008); i < 1030; i++ { + pkt := getPacketWithTransportCCExt(t, i) + assert.NoError(t, adapter.OnSent(t0, &pkt.Header, interceptor.Attributes{twccExtension: hdrExtID})) + } + + assert.NotPanics(t, func() { + // TODO(mathis): Run length seems off, maybe check why TWCC generated this? + packets, err := adapter.OnIncomingTransportCC(&rtcp.TransportLayerCC{ + Header: rtcp.Header{}, + SenderSSRC: 0, + MediaSSRC: 0, + BaseSequenceNumber: 1008, + PacketStatusCount: 8, + ReferenceTime: 278, + FbPktCount: 170, + PacketChunks: []rtcp.PacketStatusChunk{ + &rtcp.StatusVectorChunk{ + SymbolSize: rtcp.TypeTCCSymbolSizeTwoBit, + SymbolList: []uint16{ + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketReceivedSmallDelta, + rtcp.TypeTCCPacketNotReceived, + }, + }, + &rtcp.RunLengthChunk{ + PacketStatusSymbol: rtcp.TypeTCCPacketReceivedSmallDelta, + RunLength: 5632, + }, + }, + RecvDeltas: []*rtcp.RecvDelta{ + { + Type: rtcp.TypeTCCPacketReceivedSmallDelta, + Delta: 25000, + }, + { + Type: rtcp.TypeTCCPacketReceivedSmallDelta, + Delta: 0, + }, + { + Type: rtcp.TypeTCCPacketReceivedSmallDelta, + Delta: 29500, + }, + { + Type: rtcp.TypeTCCPacketReceivedSmallDelta, + Delta: 16750, + }, + { + Type: rtcp.TypeTCCPacketReceivedSmallDelta, + Delta: 23500, + }, + { + Type: rtcp.TypeTCCPacketReceivedSmallDelta, + Delta: 0, + }, + }, + }) + assert.Error(t, err) + assert.Empty(t, packets) + }) + }) +} diff --git a/pkg/cc/interceptor.go b/pkg/cc/interceptor.go new file mode 100644 index 00000000..f0f16f9c --- /dev/null +++ b/pkg/cc/interceptor.go @@ -0,0 +1,260 @@ +// Package cc implements a congestion controller interceptor that can be used +// with different congestion control algorithms. +package cc + +import ( + "errors" + "sync" + "time" + + "github.com/pion/interceptor" + "github.com/pion/interceptor/internal/types" + "github.com/pion/interceptor/pkg/gcc" + "github.com/pion/logging" + "github.com/pion/rtcp" + "github.com/pion/rtp" +) + +var errInvalidSessionID = errors.New("no bandwidth estimation for session ID") + +const transportCCURI = "http://www.ietf.org/id/draft-holmer-rmcat-transport-wide-cc-extensions-01" + +type ( + headerExtensionKey int + writerKey int +) + +const ( + twccExtension = iota + streamWriter +) + +type PacerFactory func(interceptor.RTPWriter) Pacer + +type Pacer interface { + interceptor.RTPWriter + AddStream(ssrc uint32, writer interceptor.RTPWriter) + SetTargetBitrate(types.DataRate) + Close() error +} + +type BandwidthEstimatorFactory func() BandwidthEstimator + +// BandwidthEstimator is the interface of a bandwidth estimator +type BandwidthEstimator interface { + OnPacketSent(ts time.Time, sizeInBytes int) + OnFeedback([]types.PacketResult) + GetBandwidthEstimation() types.DataRate +} + +type session struct { + i *ControllerInterceptor +} + +type Option func(*ControllerInterceptor) error + +func SetBWE(bwe BandwidthEstimatorFactory) Option { + return func(ci *ControllerInterceptor) error { + ci.BandwidthEstimator = bwe() + return nil + } +} + +type ControllerInterceptorFactory struct { + opts []Option +} + +func GCCFactory() BandwidthEstimator { + return gcc.NewSendSideBandwidthEstimator(150 * types.KiloBitPerSecond) +} + +func NewControllerInterceptor(opts ...Option) (cif *ControllerInterceptorFactory, err error) { + return &ControllerInterceptorFactory{ + opts: opts, + }, nil +} + +// NewInterceptor creates a new ControllerInterceptor +func (f *ControllerInterceptorFactory) NewInterceptor(id string) (interceptor.Interceptor, error) { + i := &ControllerInterceptor{ + NoOp: interceptor.NoOp{}, + log: logging.NewDefaultLoggerFactory().NewLogger("cc_interceptor"), + FeedbackAdapter: *NewFeedbackAdapter(), + BandwidthEstimator: GCCFactory(), + pacer: NewLeakyBucketPacer(), + twccFeedbackChan: make(chan twccFeedback), + rfc8888FeedbackChan: make(chan rfc8888Feedback), + incomingPacketChan: make(chan packetWithAttributes), + wg: sync.WaitGroup{}, + close: make(chan struct{}), + } + + for _, opt := range f.opts { + if err := opt(i); err != nil { + return nil, err + } + } + + go i.loop() + + return i, nil +} + +type twccFeedback struct { + ts time.Time + *rtcp.TransportLayerCC +} + +type rfc8888Feedback struct { + ts time.Time + *rtcp.RawPacket // TODO(mathis) change to RFC8888 packet +} + +type packetWithAttributes struct { + header rtp.Header + payload []byte + attributes interceptor.Attributes +} + +// ControllerInterceptor is an interceptor for congestion control/bandwidth +// estimation +type ControllerInterceptor struct { + interceptor.NoOp + + log logging.LeveledLogger + + FeedbackAdapter + BandwidthEstimator + + pacer Pacer + + twccFeedbackChan chan twccFeedback + rfc8888FeedbackChan chan rfc8888Feedback + incomingPacketChan chan packetWithAttributes + + wg sync.WaitGroup + close chan struct{} +} + +// BindRTCPReader lets you modify any incoming RTCP packets. It is called once per sender/receiver, however this might +// change in the future. The returned method will be called once per packet batch. +func (c *ControllerInterceptor) BindRTCPReader(reader interceptor.RTCPReader) interceptor.RTCPReader { + return interceptor.RTCPReaderFunc(func(buf []byte, attributes interceptor.Attributes) (int, interceptor.Attributes, error) { + // TODO(mathis): Put receive timestamp in attributes and populate in + // first interceptor + ts := time.Now() + + i, attr, err := reader.Read(buf, attributes) + if err != nil { + return 0, nil, err + } + if attr == nil { + attr = make(interceptor.Attributes) + } + + pkts, err := attr.GetRTCPPackets(buf[:i]) + if err != nil { + return 0, nil, err + } + for _, pkt := range pkts { + switch feedback := pkt.(type) { + case *rtcp.TransportLayerCC: + c.twccFeedbackChan <- twccFeedback{ts, feedback} + case *rtcp.RawPacket: + c.rfc8888FeedbackChan <- rfc8888Feedback{ts, feedback} + } + } + + return i, attr, nil + }) +} + +// BindLocalStream lets you modify any outgoing RTP packets. It is called once for per LocalStream. The returned method +// will be called once per rtp packet. +func (c *ControllerInterceptor) BindLocalStream(info *interceptor.StreamInfo, writer interceptor.RTPWriter) interceptor.RTPWriter { + // TODO(mathis): figure out if we have to start more loops here or create + // dedicated controllers/pacer for each stream here. + + var hdrExtID uint8 + for _, e := range info.RTPHeaderExtensions { + if e.URI == transportCCURI { + hdrExtID = uint8(e.ID) + break + } + } + if hdrExtID == 0 { // Don't try to read header extension if ID is 0, because 0 is an invalid extension ID + return writer + } + + c.pacer.AddStream(info.SSRC, interceptor.RTPWriterFunc(func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) { + + c.OnSent(time.Now(), header, attributes) + + return writer.Write(header, payload, attributes) + })) + + return interceptor.RTPWriterFunc(func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) { + if attributes == nil { + attributes = make(interceptor.Attributes) + } + attributes.Set(twccExtension, hdrExtID) + c.incomingPacketChan <- packetWithAttributes{ + header: *header, + payload: payload, + attributes: attributes, + } + + return header.MarshalSize() + len(payload), nil + }) +} + +// TODO(mathis): start loop, figure out when and how often. Probably only once +// and then add streams. +// TODO(mathis): Update bandwidth sometimes... +func (c *ControllerInterceptor) loop() { + for { + select { + case <-c.close: + return + + case pkt := <-c.incomingPacketChan: + c.pacer.Write(&pkt.header, pkt.payload, pkt.attributes) + + case feedback := <-c.twccFeedbackChan: + packetResult, err := c.OnIncomingTransportCC(feedback.TransportLayerCC) + if err != nil { + // TODO(mathis): handle error + } + c.OnFeedback(packetResult) + c.pacer.SetTargetBitrate(c.GetBandwidthEstimation()) + + case feedback := <-c.rfc8888FeedbackChan: + packetResult, err := c.OnIncomingRFC8888(feedback.RawPacket) + if err != nil { + // TODO(mathis): handle error + } + c.OnFeedback(packetResult) + c.pacer.SetTargetBitrate(c.GetBandwidthEstimation()) + } + } +} + +// Close closes the interceptor. +func (c *ControllerInterceptor) Close() error { + defer c.wg.Wait() + + if !c.isClosed() { + close(c.close) + } + + return nil +} + +func (c *ControllerInterceptor) isClosed() bool { + select { + case <-c.close: + return true + default: + return false + } +} diff --git a/pkg/cc/leaky_bucket_pacer.go b/pkg/cc/leaky_bucket_pacer.go new file mode 100644 index 00000000..475d2d39 --- /dev/null +++ b/pkg/cc/leaky_bucket_pacer.go @@ -0,0 +1,124 @@ +package cc + +import ( + "time" + + "github.com/pion/interceptor" + "github.com/pion/interceptor/internal/types" + "github.com/pion/logging" + "github.com/pion/rtp" +) + +type item struct { + header *rtp.Header + payload []byte + attributes interceptor.Attributes +} + +type LeakyBucketPacer struct { + log logging.LeveledLogger + + targetBitrate types.DataRate + pacingInterval time.Duration + + itemCh chan item + bitrateCh chan types.DataRate + streamCh chan stream + done chan struct{} + + ssrcToWriter map[uint32]interceptor.RTPWriter +} + +func NewLeakyBucketPacer() *LeakyBucketPacer { + p := &LeakyBucketPacer{ + log: logging.NewDefaultLoggerFactory().NewLogger("pacer"), + targetBitrate: types.DataRate(150_000), + pacingInterval: 5 * time.Millisecond, + itemCh: make(chan item), + bitrateCh: make(chan types.DataRate), + streamCh: make(chan stream), + done: make(chan struct{}), + ssrcToWriter: map[uint32]interceptor.RTPWriter{}, + } + go p.Run() + return p +} + +type stream struct { + ssrc uint32 + writer interceptor.RTPWriter +} + +func (p *LeakyBucketPacer) AddStream(ssrc uint32, writer interceptor.RTPWriter) { + p.streamCh <- stream{ + ssrc: ssrc, + writer: writer, + } +} + +func (p *LeakyBucketPacer) SetTargetBitrate(rate types.DataRate) { + p.targetBitrate = rate +} + +func (p *LeakyBucketPacer) Write(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) { + p.itemCh <- item{ + header: header, + payload: payload, + attributes: attributes, + } + return header.MarshalSize() + len(payload), nil +} + +func min(a, b int64) int64 { + if a < b { + return a + } + return b +} + +func (p *LeakyBucketPacer) Run() { + ticker := time.NewTicker(p.pacingInterval) + + queue := []item{} + + for { + select { + case <-p.done: + return + case rate := <-p.bitrateCh: + p.targetBitrate = rate + case stream := <-p.streamCh: + p.ssrcToWriter[stream.ssrc] = stream.writer + case item := <-p.itemCh: + queue = append(queue, item) + case <-ticker.C: + budget := p.pacingInterval.Milliseconds() * int64(p.targetBitrate.BitsPerMillisecond()) + + for len(queue) != 0 && budget > 0 { + p.log.Infof("pacer budget=%v, len(queue)=%v", budget, len(queue)) + next := queue[0] + queue = queue[1:] + writer, ok := p.ssrcToWriter[next.header.SSRC] + if !ok { + p.log.Infof("no writer found for ssrc: %v", next.header.SSRC) + } + var twcc rtp.TransportCCExtension + ext := next.header.GetExtension(next.header.GetExtensionIDs()[0]) + if err := twcc.Unmarshal(ext); err != nil { + panic(err) + } + p.log.Infof("pacer sending packet %v", twcc.TransportSequence) + n, err := writer.Write(next.header, next.payload, next.attributes) + if err != nil { + p.log.Errorf("failed to write packet: %v", err) + } + budget -= int64(n) + } + } + } +} + +func (p *LeakyBucketPacer) Close() error { + close(p.done) + return nil +} diff --git a/pkg/cc/noop_pacer.go b/pkg/cc/noop_pacer.go new file mode 100644 index 00000000..47bd92e6 --- /dev/null +++ b/pkg/cc/noop_pacer.go @@ -0,0 +1,50 @@ +package cc + +import ( + "errors" + "fmt" + "sync" + + "github.com/pion/interceptor" + "github.com/pion/interceptor/internal/types" + "github.com/pion/rtp" +) + +var ErrUnknownStream = errors.New("unknown ssrc") + +type NoOpPacer struct { + lock sync.Mutex + ssrcToWriter map[uint32]interceptor.RTPWriter +} + +func NewNoOpPacer() *NoOpPacer { + return &NoOpPacer{ + lock: sync.Mutex{}, + ssrcToWriter: map[uint32]interceptor.RTPWriter{}, + } +} + +func (p *NoOpPacer) AddStream(ssrc uint32, writer interceptor.RTPWriter) { + p.lock.Lock() + defer p.lock.Unlock() + p.ssrcToWriter[ssrc] = writer +} + +func (p *NoOpPacer) SetTargetBitrate(types.DataRate) { + +} + +func (p *NoOpPacer) Write(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) { + p.lock.Lock() + defer p.lock.Unlock() + + if w, ok := p.ssrcToWriter[header.SSRC]; ok { + return w.Write(header, payload, attributes) + } + + return 0, fmt.Errorf("%w: %v", ErrUnknownStream, header.SSRC) +} + +func (p *NoOpPacer) Close() error { + return nil +} diff --git a/pkg/gcc/delay_based_bwe.go b/pkg/gcc/delay_based_bwe.go new file mode 100644 index 00000000..2e0fd6fa --- /dev/null +++ b/pkg/gcc/delay_based_bwe.go @@ -0,0 +1,13 @@ +package gcc + +import ( + "math" + + "github.com/pion/interceptor/internal/types" +) + +type delayBasedBandwidthEstimator struct{} + +func (e *delayBasedBandwidthEstimator) getEstimate() types.DataRate { + return math.MaxInt +} diff --git a/pkg/gcc/gcc.go b/pkg/gcc/gcc.go new file mode 100644 index 00000000..6801a47e --- /dev/null +++ b/pkg/gcc/gcc.go @@ -0,0 +1,3 @@ +// Package gcc implements Google Congestion Control +// https://datatracker.ietf.org/doc/html/draft-ietf-rmcat-gcc-02 +package gcc diff --git a/pkg/gcc/loss_based_bwe.go b/pkg/gcc/loss_based_bwe.go new file mode 100644 index 00000000..aead942d --- /dev/null +++ b/pkg/gcc/loss_based_bwe.go @@ -0,0 +1,66 @@ +package gcc + +import ( + "time" + + "github.com/pion/interceptor/internal/types" + "github.com/pion/logging" +) + +type lossBasedBandwidthEstimator struct { + bitrate types.DataRate + averageLoss float64 + lastIncrease time.Time + lastDecrease time.Time + inertia float64 + decay float64 + log logging.LeveledLogger +} + +func newLossBasedBWE() *lossBasedBandwidthEstimator { + return &lossBasedBandwidthEstimator{ + inertia: 0.5, + decay: 0.5, + bitrate: 0, + averageLoss: 0, + log: logging.NewDefaultLoggerFactory().NewLogger("gcc_loss_controller"), + } +} + +func (e *lossBasedBandwidthEstimator) getEstimate(wantedRate types.DataRate) types.DataRate { + if e.bitrate <= 0 { + e.bitrate = wantedRate + } + + return e.bitrate +} + +func (e *lossBasedBandwidthEstimator) updateLossStats(results []types.PacketResult) { + + if len(results) == 0 { + return + } + + packetsLost := 0 + for _, p := range results { + if !p.Received { + packetsLost++ + } + } + + lossRatio := float64(packetsLost) / float64(len(results)) + e.averageLoss = e.inertia*lossRatio + e.decay*(1-e.inertia)*e.averageLoss + + e.log.Infof("averageLoss: %v", e.averageLoss) + + // Naive implementation using constants from IETF Draft + // TODO(mathis): Make this more smart and configurable. (Smart here means + // don't decrease too often and such things, see libwebrtc) + if e.averageLoss < 0.02 && time.Since(e.lastIncrease) > 200*time.Millisecond { + e.lastIncrease = time.Now() + e.bitrate = types.DataRate(1.05 * float64(e.bitrate)) + } else if e.averageLoss > 0.1 && time.Since(e.lastDecrease) > 200*time.Millisecond { + e.lastDecrease = time.Now() + e.bitrate = types.DataRate(float64(e.bitrate) * (1 - 0.5*e.averageLoss)) + } +} diff --git a/pkg/gcc/queue.go b/pkg/gcc/queue.go new file mode 100644 index 00000000..3cef6ec6 --- /dev/null +++ b/pkg/gcc/queue.go @@ -0,0 +1,56 @@ +package gcc + +import ( + "sync" + + "github.com/pion/interceptor" + "github.com/pion/rtp" +) + +type packetWithAttributes struct { + packet *rtp.Packet + attributes interceptor.Attributes +} + +func (p *packetWithAttributes) size() int { + return p.packet.MarshalSize() +} + +type packetWithAttributeQueue struct { + data []*packetWithAttributes + mutex sync.RWMutex +} + +func (q *packetWithAttributeQueue) Push(p *packetWithAttributes) { + q.mutex.Lock() + defer q.mutex.Unlock() + q.data = append(q.data, p) +} + +func (q *packetWithAttributeQueue) Pop() *packetWithAttributes { + q.mutex.Lock() + defer q.mutex.Unlock() + + if len(q.data) == 0 { + return nil + } + p := q.data[0] + q.data = q.data[1:] + + return p +} + +func (q *packetWithAttributeQueue) Peek() *packetWithAttributes { + q.mutex.RLock() + defer q.mutex.RUnlock() + + if len(q.data) == 0 { + return nil + } + + return q.data[0] +} + +func (q *packetWithAttributeQueue) Size() int { + return len(q.data) +} diff --git a/pkg/gcc/send_side_bwe.go b/pkg/gcc/send_side_bwe.go new file mode 100644 index 00000000..78b512a3 --- /dev/null +++ b/pkg/gcc/send_side_bwe.go @@ -0,0 +1,38 @@ +package gcc + +import ( + "time" + + "github.com/pion/interceptor/internal/types" +) + +// SendSideBandwidthEstimator implements send side bandwidth estimation +type SendSideBandwidthEstimator struct { + lastBWE types.DataRate + lossBased *lossBasedBandwidthEstimator + delayBased *delayBasedBandwidthEstimator +} + +// NewSendSideBandwidthEstimator returns a new send side bandwidth estimator +// using delay based and loss based bandwidth estimation. +func NewSendSideBandwidthEstimator(initialBitrate types.DataRate) *SendSideBandwidthEstimator { + return &SendSideBandwidthEstimator{ + lastBWE: initialBitrate, + lossBased: newLossBasedBWE(), + delayBased: &delayBasedBandwidthEstimator{}, + } +} + +// OnPacketSent records a packet as sent. +func (g *SendSideBandwidthEstimator) OnPacketSent(ts time.Time, sizeInBytes int) { +} + +// OnFeedback updates the GCC statistics from the incoming feedback. +func (g *SendSideBandwidthEstimator) OnFeedback(feedback []types.PacketResult) { + g.lossBased.updateLossStats(feedback) +} + +// GetBandwidthEstimation returns the estimated bandwidth available +func (g *SendSideBandwidthEstimator) GetBandwidthEstimation() types.DataRate { + return types.MinDataRate(g.delayBased.getEstimate(), g.lossBased.getEstimate(g.lastBWE)) +}