diff --git a/pkg/gcc/adaptive_threshold.go b/pkg/gcc/adaptive_threshold.go index 445b86bb..514607a1 100644 --- a/pkg/gcc/adaptive_threshold.go +++ b/pkg/gcc/adaptive_threshold.go @@ -92,7 +92,7 @@ func (a *adaptiveThreshold) update(estimate time.Duration) { timeDelta := time.Duration(minInt(int(now.Sub(a.lastUpdate).Milliseconds()), int(maxTimeDelta.Milliseconds()))) * time.Millisecond d := absEstimate - a.thresh add := k * float64(d.Milliseconds()) * float64(timeDelta.Milliseconds()) - a.thresh += time.Duration(add * 1000) * time.Microsecond + a.thresh += time.Duration(add*1000) * time.Microsecond a.thresh = clampDuration(a.thresh, a.min, a.max) a.lastUpdate = now } diff --git a/pkg/stats/stats_recorder.go b/pkg/stats/stats_recorder.go index 11bcc9e2..339dbdd4 100644 --- a/pkg/stats/stats_recorder.go +++ b/pkg/stats/stats_recorder.go @@ -1,6 +1,8 @@ package stats import ( + "sync" + "sync/atomic" "time" "github.com/pion/interceptor" @@ -78,12 +80,9 @@ type recorder struct { maxLastSenderReports int maxLastReceiverReferenceTimes int - incomingRTPChan chan *incomingRTP - incomingRTCPChan chan *incomingRTCP - outgoingRTPChan chan *outgoingRTP - outgoingRTCPChan chan *outgoingRTCP - getStatsChan chan Stats - done chan struct{} + latestStats *internalStats + ms *sync.Mutex // Locks latestStats + running uint32 } func newRecorder(ssrc uint32, clockRate float64) *recorder { @@ -93,21 +92,24 @@ func newRecorder(ssrc uint32, clockRate float64) *recorder { clockRate: clockRate, maxLastSenderReports: 5, maxLastReceiverReferenceTimes: 5, - incomingRTPChan: make(chan *incomingRTP), - incomingRTCPChan: make(chan *incomingRTCP), - outgoingRTPChan: make(chan *outgoingRTP), - outgoingRTCPChan: make(chan *outgoingRTCP), - getStatsChan: make(chan Stats), - done: make(chan struct{}), + latestStats: &internalStats{}, + ms: &sync.Mutex{}, } } func (r *recorder) Stop() { - close(r.done) + atomic.StoreUint32(&r.running, 0) } func (r *recorder) GetStats() Stats { - return <-r.getStatsChan + r.ms.Lock() + defer r.ms.Unlock() + return Stats{ + InboundRTPStreamStats: r.latestStats.InboundRTPStreamStats, + OutboundRTPStreamStats: r.latestStats.OutboundRTPStreamStats, + RemoteInboundRTPStreamStats: r.latestStats.RemoteInboundRTPStreamStats, + RemoteOutboundRTPStreamStats: r.latestStats.RemoteOutboundRTPStreamStats, + } } func (r *recorder) recordIncomingRTP(latestStats internalStats, v *incomingRTP) internalStats { @@ -261,38 +263,13 @@ func (r *recorder) recordIncomingRTCP(latestStats internalStats, v *incomingRTCP } func (r *recorder) Start() { - latestStats := &internalStats{} - for { - select { - case <-r.done: - return - case v := <-r.incomingRTPChan: - s := r.recordIncomingRTP(*latestStats, v) - latestStats = &s - - case v := <-r.outgoingRTCPChan: - s := r.recordOutgoingRTCP(*latestStats, v) - latestStats = &s - - case v := <-r.outgoingRTPChan: - s := r.recordOutgoingRTP(*latestStats, v) - latestStats = &s - - case v := <-r.incomingRTCPChan: - s := r.recordIncomingRTCP(*latestStats, v) - latestStats = &s - - case r.getStatsChan <- Stats{ - InboundRTPStreamStats: latestStats.InboundRTPStreamStats, - OutboundRTPStreamStats: latestStats.OutboundRTPStreamStats, - RemoteInboundRTPStreamStats: latestStats.RemoteInboundRTPStreamStats, - RemoteOutboundRTPStreamStats: latestStats.RemoteOutboundRTPStreamStats, - }: - } - } + atomic.StoreUint32(&r.running, 1) } func (r *recorder) QueueIncomingRTP(ts time.Time, buf []byte, attr interceptor.Attributes) { + if atomic.LoadUint32(&r.running) == 0 { + return + } if attr == nil { attr = make(interceptor.Attributes) } @@ -302,15 +279,20 @@ func (r *recorder) QueueIncomingRTP(ts time.Time, buf []byte, attr interceptor.A return } hdr := header.Clone() - r.incomingRTPChan <- &incomingRTP{ + r.ms.Lock() + *r.latestStats = r.recordIncomingRTP(*r.latestStats, &incomingRTP{ ts: ts, header: hdr, payloadLen: len(buf) - hdr.MarshalSize(), attr: attr, - } + }) + r.ms.Unlock() } func (r *recorder) QueueIncomingRTCP(ts time.Time, buf []byte, attr interceptor.Attributes) { + if atomic.LoadUint32(&r.running) == 0 { + return + } if attr == nil { attr = make(interceptor.Attributes) } @@ -319,29 +301,41 @@ func (r *recorder) QueueIncomingRTCP(ts time.Time, buf []byte, attr interceptor. r.logger.Warnf("failed to get RTCP packets, skipping incoming RTCP packet in stats calculation: %v", err) return } - r.incomingRTCPChan <- &incomingRTCP{ + r.ms.Lock() + *r.latestStats = r.recordIncomingRTCP(*r.latestStats, &incomingRTCP{ ts: ts, pkts: pkts, attr: attr, - } + }) + r.ms.Unlock() } func (r *recorder) QueueOutgoingRTP(ts time.Time, header *rtp.Header, payload []byte, attr interceptor.Attributes) { + if atomic.LoadUint32(&r.running) == 0 { + return + } hdr := header.Clone() - r.outgoingRTPChan <- &outgoingRTP{ + r.ms.Lock() + *r.latestStats = r.recordOutgoingRTP(*r.latestStats, &outgoingRTP{ ts: ts, header: hdr, payloadLen: len(payload), attr: attr, - } + }) + r.ms.Unlock() } func (r *recorder) QueueOutgoingRTCP(ts time.Time, pkts []rtcp.Packet, attr interceptor.Attributes) { - r.outgoingRTCPChan <- &outgoingRTCP{ + if atomic.LoadUint32(&r.running) == 0 { + return + } + r.ms.Lock() + *r.latestStats = r.recordOutgoingRTCP(*r.latestStats, &outgoingRTCP{ ts: ts, pkts: pkts, attr: attr, - } + }) + r.ms.Unlock() } func min(a, b int) int { diff --git a/pkg/stats/stats_recorder_test.go b/pkg/stats/stats_recorder_test.go index bc2e9b47..0fcf2eba 100644 --- a/pkg/stats/stats_recorder_test.go +++ b/pkg/stats/stats_recorder_test.go @@ -1,10 +1,13 @@ package stats import ( + "context" + "errors" "fmt" "testing" "time" + "github.com/pion/interceptor" "github.com/pion/interceptor/internal/ntp" "github.com/pion/rtcp" "github.com/pion/rtp" @@ -260,8 +263,7 @@ func TestStatsRecorder(t *testing.T) { t.Run(fmt.Sprintf("%v:%v", i, cc.name), func(t *testing.T) { r := newRecorder(0, 90_000) - go r.Start() - defer r.Stop() + r.Start() for _, record := range cc.records { switch v := record.content.(type) { @@ -282,6 +284,8 @@ func TestStatsRecorder(t *testing.T) { s := r.GetStats() + r.Stop() + assert.Equal(t, cc.expectedInboundRTPStreamStats, s.InboundRTPStreamStats) assert.Equal(t, cc.expectedOutboundRTPStreamStats, s.OutboundRTPStreamStats) assert.Equal(t, cc.expectedRemoteInboundRTPStreamStats, s.RemoteInboundRTPStreamStats) @@ -313,3 +317,75 @@ func TestStatsRecorder_DLRR_Precision(t *testing.T) { assert.Equal(t, int64(s.RemoteOutboundRTPStreamStats.RoundTripTime), int64(-9223372036854775808)) } + +func TestGetStatsNotBlocking(t *testing.T) { + r := newRecorder(0, 90_000) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + go func() { + defer cancel() + r.Start() + r.GetStats() + }() + go r.Stop() + + <-ctx.Done() + + if err := ctx.Err(); err != nil && errors.Is(err, context.DeadlineExceeded) { + t.Error("it shouldn't block") + } +} + +func TestQueueNotBlocking(t *testing.T) { + for _, i := range []struct { + f func(r *recorder) + name string + }{ + { + f: func(r *recorder) { + r.QueueIncomingRTP(time.Now(), mustMarshalRTP(t, rtp.Packet{}), interceptor.Attributes{}) + }, + name: "QueueIncomingRTP", + }, + { + f: func(r *recorder) { + r.QueueOutgoingRTP(time.Now(), &rtp.Header{}, mustMarshalRTP(t, rtp.Packet{}), interceptor.Attributes{}) + }, + name: "QueueOutgoingRTP", + }, + { + f: func(r *recorder) { + r.QueueIncomingRTCP(time.Now(), mustMarshalRTCPs(t, &rtcp.CCFeedbackReport{}), interceptor.Attributes{}) + }, + name: "QueueIncomingRTCP", + }, + { + f: func(r *recorder) { + r.QueueOutgoingRTCP(time.Now(), []rtcp.Packet{}, interceptor.Attributes{}) + }, + name: "QueueOutgoingRTCP", + }, + } { + t.Run(i.name+"NotBlocking", func(t *testing.T) { + r := newRecorder(0, 90_000) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + go func() { + defer cancel() + r.Start() + i.f(r) + }() + go r.Stop() + + <-ctx.Done() + + if err := ctx.Err(); err != nil && errors.Is(err, context.DeadlineExceeded) { + t.Error("it shouldn't block") + } + }) + } +}