diff --git a/pkg/cc/datarate.go b/internal/types/datarate.go similarity index 86% rename from pkg/cc/datarate.go rename to internal/types/datarate.go index 6a115e5f..fb5fdd2d 100644 --- a/pkg/cc/datarate.go +++ b/internal/types/datarate.go @@ -1,4 +1,4 @@ -package cc +package types const ( // BitPerSecond is a data rate of 1 bit per second @@ -24,3 +24,10 @@ func MaxDataRate(a, b DataRate) DataRate { } return b } + +func MinDataRate(a, b DataRate) DataRate { + if a < b { + return a + } + return b +} diff --git a/internal/types/packet_result.go b/internal/types/packet_result.go new file mode 100644 index 00000000..30f02bc7 --- /dev/null +++ b/internal/types/packet_result.go @@ -0,0 +1,20 @@ +package types + +import ( + "time" + + "github.com/pion/rtp" +) + +type SentPacket struct { + SendTime time.Time + Header rtp.Header +} + +// PacketResult holds information about a packet and if/when it has been +// sent/received. +type PacketResult struct { + SentPacket SentPacket + ReceiveTime time.Time + Received bool +} diff --git a/pkg/cc/feedback_adapter.go b/pkg/cc/feedback_adapter.go index 30ce1dd2..4a7daae6 100644 --- a/pkg/cc/feedback_adapter.go +++ b/pkg/cc/feedback_adapter.go @@ -6,6 +6,7 @@ import ( "time" "github.com/pion/interceptor" + "github.com/pion/interceptor/internal/types" "github.com/pion/rtcp" "github.com/pion/rtp" ) @@ -14,29 +15,16 @@ var errMissingTWCCExtension = errors.New("missing transport layer cc header exte // TODO(mathis): make types internal only? -type sentPacket struct { - sendTime time.Time - header rtp.Header -} - -// PacketResult holds information about a packet and if/when it has been -// sent/received. -type PacketResult struct { - SentPacket sentPacket - receiveTime time.Time - Received bool -} - // FeedbackAdapter converts incoming feedback from the wireformat to a // PacketResult type FeedbackAdapter struct { - history map[uint16]sentPacket + history map[uint16]types.SentPacket } // NewFeedbackAdapter returns a new FeedbackAdapter func NewFeedbackAdapter() *FeedbackAdapter { return &FeedbackAdapter{ - history: make(map[uint16]sentPacket), + history: make(map[uint16]types.SentPacket), } } @@ -55,17 +43,17 @@ func (f *FeedbackAdapter) OnSent(ts time.Time, pkt *rtp.Packet, attributes inter return err } - f.history[tccExt.TransportSequence] = sentPacket{ - sendTime: ts, - header: pkt.Header, + f.history[tccExt.TransportSequence] = types.SentPacket{ + SendTime: ts, + Header: pkt.Header, } return nil } // OnIncomingTransportCC converts the incoming rtcp.TransportLayerCC to a // []PacketResult -func (f *FeedbackAdapter) OnIncomingTransportCC(ts time.Time, feedback *rtcp.TransportLayerCC) []PacketResult { - result := []PacketResult{} +func (f *FeedbackAdapter) OnIncomingTransportCC(ts time.Time, feedback *rtcp.TransportLayerCC) []types.PacketResult { + result := []types.PacketResult{} baseSequenceNr := feedback.BaseSequenceNumber sequenceNr := baseSequenceNr @@ -109,9 +97,9 @@ func (f *FeedbackAdapter) OnIncomingTransportCC(ts time.Time, feedback *rtcp.Tra receiveTime := referenceTime.Add(delta) referenceTime = receiveTime - result = append(result, PacketResult{ + result = append(result, types.PacketResult{ SentPacket: sent, - receiveTime: receiveTime, + ReceiveTime: receiveTime, Received: true, }) delete(f.history, sequenceNr) @@ -153,9 +141,9 @@ func (f *FeedbackAdapter) OnIncomingTransportCC(ts time.Time, feedback *rtcp.Tra receiveTime := referenceTime.Add(delta) referenceTime = receiveTime - result = append(result, PacketResult{ + result = append(result, types.PacketResult{ SentPacket: sent, - receiveTime: receiveTime, + ReceiveTime: receiveTime, Received: true, }) delete(f.history, sequenceNr) @@ -167,16 +155,16 @@ func (f *FeedbackAdapter) OnIncomingTransportCC(ts time.Time, feedback *rtcp.Tra } for _, v := range sortedKeysUint16(f.history) { - result = append(result, PacketResult{ + result = append(result, types.PacketResult{ SentPacket: f.history[v], - receiveTime: time.Time{}, + ReceiveTime: time.Time{}, Received: false, }) } return result } -func sortedKeysUint16(m map[uint16]sentPacket) []uint16 { +func sortedKeysUint16(m map[uint16]types.SentPacket) []uint16 { var result []uint16 for k := range m { result = append(result, k) @@ -188,6 +176,6 @@ func sortedKeysUint16(m map[uint16]sentPacket) []uint16 { } // OnIncomingRFC8888 converts the incoming RFC8888 packet to a []PacketResult -func (f *FeedbackAdapter) OnIncomingRFC8888(ts time.Time, feedback *rtcp.RawPacket) []PacketResult { +func (f *FeedbackAdapter) OnIncomingRFC8888(ts time.Time, feedback *rtcp.RawPacket) []types.PacketResult { return nil } diff --git a/pkg/cc/feedback_adapter_test.go b/pkg/cc/feedback_adapter_test.go index 677e813a..c6136516 100644 --- a/pkg/cc/feedback_adapter_test.go +++ b/pkg/cc/feedback_adapter_test.go @@ -5,6 +5,7 @@ import ( "time" "github.com/pion/interceptor" + "github.com/pion/interceptor/internal/types" "github.com/pion/rtcp" "github.com/pion/rtp" "github.com/stretchr/testify/assert" @@ -39,12 +40,12 @@ func TestFeedbackAdapterTWCC(t *testing.T) { assert.NoError(t, adapter.OnSent(time.Time{}, pkt, interceptor.Attributes{twccExtension: hdrExtID})) result := adapter.OnIncomingTransportCC(time.Time{}, &rtcp.TransportLayerCC{}) assert.NotEmpty(t, result) - assert.Contains(t, result, PacketResult{ - SentPacket: sentPacket{ - sendTime: time.Time{}, - header: pkt.Header, + assert.Contains(t, result, types.PacketResult{ + SentPacket: types.SentPacket{ + SendTime: time.Time{}, + Header: pkt.Header, }, - receiveTime: time.Time{}, + ReceiveTime: time.Time{}, Received: false, }) }) @@ -131,61 +132,61 @@ func TestFeedbackAdapterTWCC(t *testing.T) { assert.NotEmpty(t, results) assert.Len(t, results, 22) - assert.Contains(t, results, PacketResult{ - SentPacket: sentPacket{ - sendTime: t0, - header: headers[0], + assert.Contains(t, results, types.PacketResult{ + SentPacket: types.SentPacket{ + SendTime: t0, + Header: headers[0], }, - receiveTime: t0.Add(time.Millisecond), + ReceiveTime: t0.Add(time.Millisecond), Received: true, }) - assert.Contains(t, results, PacketResult{ - SentPacket: sentPacket{ - sendTime: t0, - header: headers[1], + assert.Contains(t, results, types.PacketResult{ + SentPacket: types.SentPacket{ + SendTime: t0, + Header: headers[1], }, - receiveTime: t0.Add(101 * time.Millisecond), + ReceiveTime: t0.Add(101 * time.Millisecond), Received: true, }) for i := uint16(2); i < 7; i++ { - assert.Contains(t, results, PacketResult{ - SentPacket: sentPacket{ - sendTime: t0, - header: headers[i], + assert.Contains(t, results, types.PacketResult{ + SentPacket: types.SentPacket{ + SendTime: t0, + Header: headers[i], }, - receiveTime: time.Time{}, + ReceiveTime: time.Time{}, Received: false, }) } - assert.Contains(t, results, PacketResult{ - SentPacket: sentPacket{ - sendTime: t0, - header: headers[7], + assert.Contains(t, results, types.PacketResult{ + SentPacket: types.SentPacket{ + SendTime: t0, + Header: headers[7], }, - receiveTime: t0.Add(104 * time.Millisecond), + ReceiveTime: t0.Add(104 * time.Millisecond), Received: true, }) for i := uint16(8); i < 21; i++ { - assert.Contains(t, results, PacketResult{ - SentPacket: sentPacket{ - sendTime: t0, - header: headers[i], + assert.Contains(t, results, types.PacketResult{ + SentPacket: types.SentPacket{ + SendTime: t0, + Header: headers[i], }, - receiveTime: time.Time{}, + ReceiveTime: time.Time{}, Received: false, }) } - assert.Contains(t, results, PacketResult{ - SentPacket: sentPacket{ - sendTime: t0, - header: headers[21], + assert.Contains(t, results, types.PacketResult{ + SentPacket: types.SentPacket{ + SendTime: t0, + Header: headers[21], }, - receiveTime: t0.Add(105 * time.Millisecond), + ReceiveTime: t0.Add(105 * time.Millisecond), Received: true, }) }) @@ -274,20 +275,20 @@ func TestFeedbackAdapterTWCC(t *testing.T) { assert.NotEmpty(t, results) assert.Len(t, results, 2) - assert.Contains(t, results, PacketResult{ - SentPacket: sentPacket{ - sendTime: t0, - header: pkt65535.Header, + assert.Contains(t, results, types.PacketResult{ + SentPacket: types.SentPacket{ + SendTime: t0, + Header: pkt65535.Header, }, - receiveTime: t0.Add(1 * time.Millisecond), + ReceiveTime: t0.Add(1 * time.Millisecond), Received: true, }) - assert.Contains(t, results, PacketResult{ - SentPacket: sentPacket{ - sendTime: t0, - header: pkt0.Header, + assert.Contains(t, results, types.PacketResult{ + SentPacket: types.SentPacket{ + SendTime: t0, + Header: pkt0.Header, }, - receiveTime: t0.Add(2 * time.Millisecond), + ReceiveTime: t0.Add(2 * time.Millisecond), Received: true, }) }) diff --git a/pkg/cc/interceptor.go b/pkg/cc/interceptor.go index 74ed404d..b8dacb93 100644 --- a/pkg/cc/interceptor.go +++ b/pkg/cc/interceptor.go @@ -3,10 +3,13 @@ package cc import ( + "fmt" "sync" "time" "github.com/pion/interceptor" + "github.com/pion/interceptor/internal/types" + "github.com/pion/interceptor/pkg/gcc" "github.com/pion/rtcp" "github.com/pion/rtp" ) @@ -20,26 +23,41 @@ type Pacer interface { // BandwidthEstimator is the interface of a bandwidth estimator type BandwidthEstimator interface { OnPacketSent(ts time.Time, sizeInBytes int) - OnFeedback(time.Time, []PacketResult) - GetBandwidthEstimation(time.Time) DataRate + OnFeedback(time.Time, []types.PacketResult) + GetBandwidthEstimation(time.Time) types.DataRate } // TODO(mathis): implement options? type ControllerInterceptorFactory struct { + i interceptor.Interceptor } // NewInterceptor creates a new ControllerInterceptor func (f *ControllerInterceptorFactory) NewInterceptor(id string) (interceptor.Interceptor, error) { - return &ControllerInterceptor{ + if f.i != nil { + return f.i, nil + } + sendSideBWE, err := gcc.NewSendSideBandwidthEstimator(150 * types.KiloBitPerSecond) + if err != nil { + return nil, err + } + pacer, err := gcc.NewLeakyBucketPacer(150 * types.KiloBitPerSecond) + if err != nil { + return nil, err + } + i := &ControllerInterceptor{ NoOp: interceptor.NoOp{}, FeedbackAdapter: *NewFeedbackAdapter(), - BandwidthEstimator: nil, - Pacer: nil, + BandwidthEstimator: sendSideBWE, + Pacer: pacer, twccFeedbackChan: make(chan twccFeedback), rfc8888FeedbackChan: make(chan rfc8888Feedback), + incomingPacketChan: make(chan packetWithAttributes), wg: sync.WaitGroup{}, close: make(chan struct{}), - }, nil + } + go i.loop() + return i, nil } type twccFeedback struct { @@ -114,9 +132,15 @@ func (c *ControllerInterceptor) BindRTCPReader(reader interceptor.RTCPReader) in const transportCCURI = "http://www.ietf.org/id/draft-holmer-rmcat-transport-wide-cc-extensions-01" -type headerExtensionKey int +type ( + headerExtensionKey int + writerKey int +) -const twccExtension = iota +const ( + twccExtension = iota + streamWriter +) // BindLocalStream lets you modify any outgoing RTP packets. It is called once for per LocalStream. The returned method // will be called once per rtp packet. @@ -134,12 +158,14 @@ func (c *ControllerInterceptor) BindLocalStream(info *interceptor.StreamInfo, wr if hdrExtID == 0 { // Don't try to read header extension if ID is 0, because 0 is an invalid extension ID return writer } + return interceptor.RTPWriterFunc(func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) { pkt := &rtp.Packet{ Header: *header, Payload: payload, } attributes.Set(twccExtension, hdrExtID) + attributes.Set(streamWriter, writer) c.incomingPacketChan <- packetWithAttributes{ packet: pkt, attributes: attributes, @@ -152,21 +178,26 @@ func (c *ControllerInterceptor) BindLocalStream(info *interceptor.StreamInfo, wr // TODO(mathis): start loop, figure out when and how often. Probably only once // and then add streams. // TODO(mathis): Update bandwidth sometimes... -func (c *ControllerInterceptor) loop(writer interceptor.RTPWriter) { +func (c *ControllerInterceptor) loop() { ticker := time.NewTicker(5 * time.Millisecond) for { + fmt.Printf("bwe: %v\n", c.GetBandwidthEstimation(time.Now())) select { case <-c.close: return case now := <-ticker.C: - for pkt, attributes := c.GetPendingPacket(now); pkt != nil; { + pkt, attributes := c.GetPendingPacket(now) + for pkt != nil { + writer := attributes.Get(streamWriter).(interceptor.RTPWriter) n, err := writer.Write(&pkt.Header, pkt.Payload, attributes) if err != nil { // TODO(mathis): Handle error + panic(fmt.Errorf("TODO: Handle error: %w", err)) } c.OnPacketSent(now, n) c.OnSent(now, pkt, attributes) + pkt, attributes = c.GetPendingPacket(now) } case pkt := <-c.incomingPacketChan: diff --git a/pkg/gcc/delay_based_bwe.go b/pkg/gcc/delay_based_bwe.go index 4d616a83..2e0fd6fa 100644 --- a/pkg/gcc/delay_based_bwe.go +++ b/pkg/gcc/delay_based_bwe.go @@ -3,11 +3,11 @@ package gcc import ( "math" - "github.com/pion/interceptor/pkg/cc" + "github.com/pion/interceptor/internal/types" ) type delayBasedBandwidthEstimator struct{} -func (e *delayBasedBandwidthEstimator) getEstimate() cc.DataRate { +func (e *delayBasedBandwidthEstimator) getEstimate() types.DataRate { return math.MaxInt } diff --git a/pkg/gcc/leaky_bucket.go b/pkg/gcc/leaky_bucket.go index 2a226f52..36d88266 100644 --- a/pkg/gcc/leaky_bucket.go +++ b/pkg/gcc/leaky_bucket.go @@ -3,7 +3,7 @@ package gcc import ( "time" - "github.com/pion/interceptor/pkg/cc" + "github.com/pion/interceptor/internal/types" ) const ( @@ -13,13 +13,13 @@ const ( ) type leakyBucket struct { - targetBitrate cc.DataRate + targetBitrate types.DataRate lastConsume time.Time budget int maxPacingDebt time.Duration } -func newLeakyBucket(initialBitrate cc.DataRate, maxPacingDebt time.Duration) *leakyBucket { +func newLeakyBucket(initialBitrate types.DataRate, maxPacingDebt time.Duration) *leakyBucket { b := &leakyBucket{ targetBitrate: initialBitrate, lastConsume: time.Time{}, diff --git a/pkg/gcc/leaky_bucket_pacer.go b/pkg/gcc/leaky_bucket_pacer.go index ada20286..03d57b30 100644 --- a/pkg/gcc/leaky_bucket_pacer.go +++ b/pkg/gcc/leaky_bucket_pacer.go @@ -4,7 +4,7 @@ import ( "time" "github.com/pion/interceptor" - "github.com/pion/interceptor/pkg/cc" + "github.com/pion/interceptor/internal/types" "github.com/pion/rtp" ) @@ -16,7 +16,7 @@ type LeakyBucketPacer struct { // NewLeakyBucketPacer creates a new LeakyBucketPacer // TODO(mathis): Add options -func NewLeakyBucketPacer(targetBitrateBps cc.DataRate, opts ...LeakyBucketPacerOption) (*LeakyBucketPacer, error) { +func NewLeakyBucketPacer(targetBitrateBps types.DataRate, opts ...LeakyBucketPacerOption) (*LeakyBucketPacer, error) { lb := newLeakyBucket(targetBitrateBps, 5*time.Millisecond) se := &LeakyBucketPacer{ queue: &packetWithAttributeQueue{}, diff --git a/pkg/gcc/leaky_bucket_pacer_test.go b/pkg/gcc/leaky_bucket_pacer_test.go index 485562c7..c9cfdb25 100644 --- a/pkg/gcc/leaky_bucket_pacer_test.go +++ b/pkg/gcc/leaky_bucket_pacer_test.go @@ -4,14 +4,14 @@ import ( "testing" "time" - "github.com/pion/interceptor/pkg/cc" + "github.com/pion/interceptor/internal/types" "github.com/pion/rtp" "github.com/stretchr/testify/assert" ) func TestLeakyBucketPacer(t *testing.T) { t.Run("emptyReturnsNil", func(t *testing.T) { - lbp, err := NewLeakyBucketPacer(cc.MegaBitPerSecond) + lbp, err := NewLeakyBucketPacer(types.MegaBitPerSecond) assert.NoError(t, err) pkt, attr := lbp.GetPendingPacket(time.Time{}) assert.Nil(t, pkt) @@ -19,7 +19,7 @@ func TestLeakyBucketPacer(t *testing.T) { }) t.Run("zeroTimeReturnsPacket", func(t *testing.T) { - lbp, err := NewLeakyBucketPacer(cc.KiloBitPerSecond) + lbp, err := NewLeakyBucketPacer(types.KiloBitPerSecond) assert.NoError(t, err) packet := &rtp.Packet{ @@ -38,7 +38,7 @@ func TestLeakyBucketPacer(t *testing.T) { }) t.Run("usesFullBudget", func(t *testing.T) { - lbp, err := NewLeakyBucketPacer(cc.MegaBitPerSecond) + lbp, err := NewLeakyBucketPacer(types.MegaBitPerSecond) assert.NoError(t, err) for i := 0; i < 12; i++ { @@ -78,7 +78,7 @@ func TestLeakyBucketPacer(t *testing.T) { }) t.Run("pacesCorrectly", func(t *testing.T) { - lbp, err := NewLeakyBucketPacer(5 * cc.MegaBitPerSecond) + lbp, err := NewLeakyBucketPacer(5 * types.MegaBitPerSecond) assert.NoError(t, err) for frame := 0; frame < 30; frame++ { diff --git a/pkg/gcc/leaky_bucket_test.go b/pkg/gcc/leaky_bucket_test.go index 2a9b9108..22751a2d 100644 --- a/pkg/gcc/leaky_bucket_test.go +++ b/pkg/gcc/leaky_bucket_test.go @@ -5,22 +5,22 @@ import ( "testing" "time" - "github.com/pion/interceptor/pkg/cc" + "github.com/pion/interceptor/internal/types" "github.com/stretchr/testify/assert" ) func TestLeakyBucket(t *testing.T) { for _, tc := range []struct { - bitrate cc.DataRate + bitrate types.DataRate maxBytesPerMs int }{ - {cc.KiloBitPerSecond, maxPacketSize}, - {cc.MegaBitPerSecond, maxPacketSize}, - {10 * cc.MegaBitPerSecond, maxPacketSize}, - {100 * cc.MegaBitPerSecond, 12_500}, - {1000 * cc.MegaBitPerSecond, 125_000}, + {types.KiloBitPerSecond, maxPacketSize}, + {types.MegaBitPerSecond, maxPacketSize}, + {10 * types.MegaBitPerSecond, maxPacketSize}, + {100 * types.MegaBitPerSecond, 12_500}, + {1000 * types.MegaBitPerSecond, 125_000}, } { - func(bitrate cc.DataRate, maxBytesPerMs int) { + func(bitrate types.DataRate, maxBytesPerMs int) { t.Run("initialBudget", func(t *testing.T) { lb := newLeakyBucket(bitrate, time.Millisecond) assert.Equal(t, maxBytesPerMs, lb.maxBurstSize()) @@ -30,7 +30,7 @@ func TestLeakyBucket(t *testing.T) { } t.Run("reduceBudgetOnConsume", func(t *testing.T) { - lb := newLeakyBucket(cc.MegaBitPerSecond, time.Millisecond) + lb := newLeakyBucket(types.MegaBitPerSecond, time.Millisecond) t0 := time.Time{} for i := 0; i < 12; i++ { lb.updateBudgetWithElapsedTime(t0) @@ -41,7 +41,7 @@ func TestLeakyBucket(t *testing.T) { }) t.Run("refillAfterElapsedTime", func(t *testing.T) { - lb := newLeakyBucket(100*cc.MegaBitPerSecond, time.Millisecond) + lb := newLeakyBucket(100*types.MegaBitPerSecond, time.Millisecond) t0 := time.Now() assert.Equal(t, 12500, lb.Budget()) lb.updateWithSentData(t0, 12000) @@ -51,16 +51,16 @@ func TestLeakyBucket(t *testing.T) { }) for _, tc := range []struct { - rate cc.DataRate + rate types.DataRate interval time.Duration budgetPerInterval int }{ - {rate: 5 * cc.MegaBitPerSecond, interval: 5 * time.Millisecond, budgetPerInterval: 3125}, - {rate: 1 * cc.MegaBitPerSecond, interval: 10 * time.Millisecond, budgetPerInterval: 1250}, - {rate: 150 * cc.KiloBitPerSecond, interval: 100 * time.Millisecond, budgetPerInterval: 1875}, - {rate: 500 * cc.KiloBitPerSecond, interval: 60 * time.Millisecond, budgetPerInterval: 3750}, + {rate: 5 * types.MegaBitPerSecond, interval: 5 * time.Millisecond, budgetPerInterval: 3125}, + {rate: 1 * types.MegaBitPerSecond, interval: 10 * time.Millisecond, budgetPerInterval: 1250}, + {rate: 150 * types.KiloBitPerSecond, interval: 100 * time.Millisecond, budgetPerInterval: 1875}, + {rate: 500 * types.KiloBitPerSecond, interval: 60 * time.Millisecond, budgetPerInterval: 3750}, } { - func(rate cc.DataRate, interval time.Duration, budgetPerInterval int) { + func(rate types.DataRate, interval time.Duration, budgetPerInterval int) { t.Run(fmt.Sprintf("pacesOut%vbpsTo%vTicks", rate, interval), func(t *testing.T) { lb := newLeakyBucket(rate, interval) diff --git a/pkg/gcc/loss_based_bwe.go b/pkg/gcc/loss_based_bwe.go index e47a34dc..edf19561 100644 --- a/pkg/gcc/loss_based_bwe.go +++ b/pkg/gcc/loss_based_bwe.go @@ -1,10 +1,11 @@ package gcc import ( + "fmt" "math" "time" - "github.com/pion/interceptor/pkg/cc" + "github.com/pion/interceptor/internal/types" ) type lossBasedBWEConfig struct { @@ -16,13 +17,13 @@ type lossBasedBWEConfig struct { type lossBasedBandwidthEstimator struct { config lossBasedBWEConfig - bitrate cc.DataRate + bitrate types.DataRate averageLoss float64 maxAverageLoss float64 averageLossMax float64 lastLossReport time.Time - maxAcknowledgedRate cc.DataRate + maxAcknowledgedRate types.DataRate lastAcknowledgedRateReport time.Time } @@ -45,24 +46,25 @@ func newLossBasedBWE() *lossBasedBandwidthEstimator { } } -func (e *lossBasedBandwidthEstimator) getEstimate(wantedRate cc.DataRate) cc.DataRate { +func (e *lossBasedBandwidthEstimator) getEstimate(wantedRate types.DataRate) types.DataRate { if e.bitrate == 0 { e.bitrate = wantedRate } + fmt.Printf("maxAverageLoss=%v\n", e.maxAverageLoss) // Naive implementation using constants from IETF Draft // TODO(mathis): Make this more smart and configurable. (Smart here means // don't decrease too often and such things, see libwebrtc) if e.maxAverageLoss < 0.02 { - e.bitrate = cc.DataRate(float64(e.bitrate) * (1 - 0.5*e.maxAverageLoss)) + e.bitrate = types.DataRate(1.05 * float64(e.bitrate)) } else if e.maxAverageLoss > 0.1 { - e.bitrate = cc.DataRate(1.05 * float64(e.bitrate)) + e.bitrate = types.DataRate(float64(e.bitrate) * (1 - 0.5*e.maxAverageLoss)) } return e.bitrate } -func (e *lossBasedBandwidthEstimator) updateLossStats(now time.Time, results []cc.PacketResult) { +func (e *lossBasedBandwidthEstimator) updateLossStats(now time.Time, results []types.PacketResult) { packetsLost := 0 for _, p := range results { if !p.Received { @@ -82,13 +84,13 @@ func (e *lossBasedBandwidthEstimator) updateLossStats(now time.Time, results []c } } -func (e *lossBasedBandwidthEstimator) updateAcknowledgedBitrate(now time.Time, acknowledgedRate cc.DataRate) { +func (e *lossBasedBandwidthEstimator) updateAcknowledgedBitrate(now time.Time, acknowledgedRate types.DataRate) { delta := deltaOrDefault(e.lastAcknowledgedRateReport, now, time.Second) if acknowledgedRate > e.maxAcknowledgedRate { e.maxAcknowledgedRate = acknowledgedRate } else { // TODO(mathis): Double check these type conversions - e.maxAcknowledgedRate -= cc.DataRate(exponentialUpdate(delta, e.config.maxAcknowledgedRateWindow) * float64(e.maxAcknowledgedRate-acknowledgedRate)) + e.maxAcknowledgedRate -= types.DataRate(exponentialUpdate(delta, e.config.maxAcknowledgedRateWindow) * float64(e.maxAcknowledgedRate-acknowledgedRate)) } } diff --git a/pkg/gcc/send_side_bwe.go b/pkg/gcc/send_side_bwe.go index 39cb7d28..628d2e28 100644 --- a/pkg/gcc/send_side_bwe.go +++ b/pkg/gcc/send_side_bwe.go @@ -1,40 +1,38 @@ package gcc import ( - "math" "time" - "github.com/pion/interceptor/pkg/cc" + "github.com/pion/interceptor/internal/types" ) // SendSideBandwidthEstimator implements send side bandwidth estimation type SendSideBandwidthEstimator struct { + lastBWE types.DataRate lossBased *lossBasedBandwidthEstimator delayBased *delayBasedBandwidthEstimator } // NewSendSideBandwidthEstimator returns a new send side bandwidth estimator // using delay based and loss based bandwidth estimation. -func NewSendSideBandwidthEstimator() (*SendSideBandwidthEstimator, error) { +func NewSendSideBandwidthEstimator(initialBitrate types.DataRate) (*SendSideBandwidthEstimator, error) { return &SendSideBandwidthEstimator{ + lastBWE: initialBitrate, lossBased: newLossBasedBWE(), delayBased: &delayBasedBandwidthEstimator{}, }, nil } -// TODO(mathis): remove when BandwidthEstimator interface is stable -var _ cc.BandwidthEstimator = &SendSideBandwidthEstimator{} - // OnPacketSent records a packet as sent. func (g *SendSideBandwidthEstimator) OnPacketSent(ts time.Time, sizeInBytes int) { } // OnFeedback updates the GCC statistics from the incoming feedback. -func (g *SendSideBandwidthEstimator) OnFeedback(ts time.Time, feedback []cc.PacketResult) { +func (g *SendSideBandwidthEstimator) OnFeedback(ts time.Time, feedback []types.PacketResult) { g.lossBased.updateLossStats(ts, feedback) } // GetBandwidthEstimation returns the estimated bandwidth available -func (g *SendSideBandwidthEstimator) GetBandwidthEstimation(now time.Time) cc.DataRate { - return cc.MaxDataRate(g.delayBased.getEstimate(), g.lossBased.getEstimate(cc.DataRate(math.MaxInt))) +func (g *SendSideBandwidthEstimator) GetBandwidthEstimation(now time.Time) types.DataRate { + return types.MinDataRate(g.delayBased.getEstimate(), g.lossBased.getEstimate(g.lastBWE)) }