From cfd4eaf7735b1ab2745175363eb6abfa67deb574 Mon Sep 17 00:00:00 2001 From: OrlandoCo Date: Thu, 7 Jan 2021 09:53:58 -0600 Subject: [PATCH] fix(buffer): Fix buckets pools (#369) * fix(buffer): Fix buckets pools * Inline writing packets to avoid MTU limit * Remove downtracks from subs --- pkg/buffer/bucket.go | 15 +++------------ pkg/buffer/bucket_test.go | 4 ++-- pkg/buffer/buffer.go | 10 ++++------ pkg/buffer/buffer_test.go | 2 +- pkg/buffer/factory.go | 4 ++-- pkg/buffer/nack.go | 8 -------- pkg/buffer/nack_test.go | 16 +++++++++++++++- pkg/sfu/router.go | 15 ++++++++------- pkg/sfu/subscriber.go | 39 +++++++++++++++++++++++++++++---------- 9 files changed, 64 insertions(+), 49 deletions(-) diff --git a/pkg/buffer/bucket.go b/pkg/buffer/bucket.go index ed3fd2d1d..ff52bac79 100644 --- a/pkg/buffer/bucket.go +++ b/pkg/buffer/bucket.go @@ -20,10 +20,10 @@ type Bucket struct { onLost func(nack []rtcp.NackPair, askKeyframe bool) } -func NewBucket(size int, nack bool) *Bucket { +func NewBucket(buf []byte, nack bool) *Bucket { b := &Bucket{ - buf: make([]byte, size), - maxSteps: int(math.Floor(float64(size)/float64(maxPktSize))) - 1, + buf: buf, + maxSteps: int(math.Floor(float64(len(buf))/float64(maxPktSize))) - 1, } if nack { b.nacker = newNACKQueue() @@ -117,12 +117,3 @@ func (b *Bucket) set(sn uint16, pkt []byte) []byte { copy(b.buf[off+2:], pkt) return b.buf[off+2 : off+2+len(pkt)] } - -func (b *Bucket) reset() { - b.headSN = 0 - b.step = 0 - b.onLost = nil - if b.nacker != nil { - b.nacker.reset() - } -} diff --git a/pkg/buffer/bucket_test.go b/pkg/buffer/bucket_test.go index 1ee0fe68d..e387d7941 100644 --- a/pkg/buffer/bucket_test.go +++ b/pkg/buffer/bucket_test.go @@ -44,7 +44,7 @@ var TestPackets = []*rtp.Packet{ } func Test_queue(t *testing.T) { - q := NewBucket(2*1000*1000, true) + q := NewBucket(make([]byte, 25000), true) q.onLost = func(_ []rtcp.NackPair, _ bool) { } @@ -100,7 +100,7 @@ func Test_queue_edges(t *testing.T) { }, }, } - q := NewBucket(2*1000*1000, true) + q := NewBucket(make([]byte, 25000), true) q.onLost = func(_ []rtcp.NackPair, _ bool) { } q.headSN = 65532 diff --git a/pkg/buffer/buffer.go b/pkg/buffer/buffer.go index 2febe87db..641c92216 100644 --- a/pkg/buffer/buffer.go +++ b/pkg/buffer/buffer.go @@ -111,12 +111,10 @@ func (b *Buffer) Bind(params webrtc.RTPParameters, o Options) { switch { case strings.HasPrefix(codec.MimeType, "audio/"): b.codecType = webrtc.RTPCodecTypeAudio - b.bucket = b.audioPool.Get().(*Bucket) - b.bucket.reset() + b.bucket = NewBucket(b.audioPool.Get().([]byte), false) case strings.HasPrefix(codec.MimeType, "video/"): b.codecType = webrtc.RTPCodecTypeVideo - b.bucket = b.videoPool.Get().(*Bucket) - b.bucket.reset() + b.bucket = NewBucket(b.videoPool.Get().([]byte), true) default: b.codecType = webrtc.RTPCodecType(0) } @@ -224,10 +222,10 @@ func (b *Buffer) Close() error { b.closed = true if b.bucket != nil && b.codecType == webrtc.RTPCodecTypeVideo { - b.videoPool.Put(b.bucket) + b.videoPool.Put(b.bucket.buf) } if b.bucket != nil && b.codecType == webrtc.RTPCodecTypeAudio { - b.audioPool.Put(b.bucket) + b.audioPool.Put(b.bucket.buf) } b.onClose() close(b.packetChan) diff --git a/pkg/buffer/buffer_test.go b/pkg/buffer/buffer_test.go index 2996b498d..189ca5599 100644 --- a/pkg/buffer/buffer_test.go +++ b/pkg/buffer/buffer_test.go @@ -90,7 +90,7 @@ func TestNewBuffer(t *testing.T) { } pool := &sync.Pool{ New: func() interface{} { - return NewBucket(2*1000*1000, true) + return make([]byte, 30000) }, } buff := NewBuffer(123, pool, pool) diff --git a/pkg/buffer/factory.go b/pkg/buffer/factory.go index b91c572fc..b353d2839 100644 --- a/pkg/buffer/factory.go +++ b/pkg/buffer/factory.go @@ -20,13 +20,13 @@ func NewBufferFactory() *Factory { videoPool: &sync.Pool{ New: func() interface{} { // Make a 2MB buffer for video - return NewBucket(2*1000*1000, true) + return make([]byte, 2*1000*1000) }, }, audioPool: &sync.Pool{ New: func() interface{} { // Make a max 25 packets buffer for audio - return NewBucket(maxPktSize*25, false) + return make([]byte, maxPktSize*25) }, }, rtpBuffers: make(map[uint32]*Buffer), diff --git a/pkg/buffer/nack.go b/pkg/buffer/nack.go index e22012a46..efb1ad573 100644 --- a/pkg/buffer/nack.go +++ b/pkg/buffer/nack.go @@ -30,14 +30,6 @@ func newNACKQueue() *nackQueue { } } -func (n *nackQueue) reset() { - n.maxSN = 0 - n.counter = 0 - n.cycles = 0 - n.kfSN = 0 - n.nacks = n.nacks[:0] -} - func (n *nackQueue) remove(sn uint16) { var extSN uint32 if sn > n.maxSN && sn&0x8000 == 1 && n.maxSN&0x8000 == 0 { diff --git a/pkg/buffer/nack_test.go b/pkg/buffer/nack_test.go index ab3404c5f..ec7907a7b 100644 --- a/pkg/buffer/nack_test.go +++ b/pkg/buffer/nack_test.go @@ -1,8 +1,10 @@ package buffer import ( + "math/rand" "reflect" "testing" + "time" "github.com/pion/rtcp" "github.com/stretchr/testify/assert" @@ -113,7 +115,7 @@ func Test_nackQueue_push(t *testing.T) { } } -func Test_nackQueue_pushAndNack(t *testing.T) { +func Test_nackQueue(t *testing.T) { type fields struct { nacks []nack cycles uint32 @@ -140,7 +142,19 @@ func Test_nackQueue_pushAndNack(t *testing.T) { for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { + n := nackQueue{} + r := rand.New(rand.NewSource(time.Now().UnixNano())) + for i := 0; i < 100; i++ { + assert.NotPanics(t, func() { + n.push(uint16(r.Intn(60000))) + n.remove(uint16(r.Intn(60000))) + n.pairs() + }) + } + for _, sn := range n.nacks { + print(sn.sn, ",") + } }) } } diff --git a/pkg/sfu/router.go b/pkg/sfu/router.go index 43836a46e..0c9d99c85 100644 --- a/pkg/sfu/router.go +++ b/pkg/sfu/router.go @@ -198,7 +198,7 @@ func (r *router) addDownTrack(sub *Subscriber, recv Receiver) error { return err } - outTrack, err := NewDownTrack(webrtc.RTPCodecCapability{ + downTrack, err := NewDownTrack(webrtc.RTPCodecCapability{ MimeType: codec.MimeType, ClockRate: codec.ClockRate, Channels: codec.Channels, @@ -209,27 +209,28 @@ func (r *router) addDownTrack(sub *Subscriber, recv Receiver) error { return err } // Create webrtc sender for the peer we are sending track to - if outTrack.transceiver, err = sub.pc.AddTransceiverFromTrack(outTrack, webrtc.RTPTransceiverInit{ + if downTrack.transceiver, err = sub.pc.AddTransceiverFromTrack(downTrack, webrtc.RTPTransceiverInit{ Direction: webrtc.RTPTransceiverDirectionSendonly, }); err != nil { return err } // nolint:scopelint - outTrack.OnCloseHandler(func() { - if err := sub.pc.RemoveTrack(outTrack.transceiver.Sender()); err != nil { + downTrack.OnCloseHandler(func() { + if err := sub.pc.RemoveTrack(downTrack.transceiver.Sender()); err != nil { log.Errorf("Error closing down track: %v", err) } else { + sub.RemoveDownTrack(recv.StreamID(), downTrack) sub.negotiate() } }) - outTrack.OnBind(func() { + downTrack.OnBind(func() { go sub.sendStreamDownTracksReports(recv.StreamID()) }) - sub.AddDownTrack(recv.StreamID(), outTrack) - recv.AddDownTrack(outTrack, r.config.Simulcast.BestQualityFirst) + sub.AddDownTrack(recv.StreamID(), downTrack) + recv.AddDownTrack(downTrack, r.config.Simulcast.BestQualityFirst) return nil } diff --git a/pkg/sfu/subscriber.go b/pkg/sfu/subscriber.go index 06d46c402..34dffd4b7 100644 --- a/pkg/sfu/subscriber.go +++ b/pkg/sfu/subscriber.go @@ -2,7 +2,6 @@ package sfu import ( "io" - "math" "sync" "sync/atomic" "time" @@ -126,6 +125,23 @@ func (s *Subscriber) AddDownTrack(streamID string, downTrack *DownTrack) { } } +func (s *Subscriber) RemoveDownTrack(streamID string, downTrack *DownTrack) { + s.Lock() + defer s.Unlock() + if dts, ok := s.tracks[streamID]; ok { + idx := -1 + for i, dt := range dts { + if dt == downTrack { + idx = i + } + } + dts[idx] = dts[len(dts)-1] + dts[len(dts)-1] = nil + dts = dts[:len(dts)-1] + s.tracks[streamID] = dts + } +} + func (s *Subscriber) AddDataChannel(label string) (*webrtc.DataChannel, error) { s.Lock() defer s.Unlock() @@ -177,6 +193,10 @@ func (s *Subscriber) downTracksReports() { for { time.Sleep(5 * time.Second) + if s.pc.ConnectionState() == webrtc.ICETransportStateClosed { + return + } + var r []rtcp.Packet var sd []rtcp.SourceDescriptionChunk s.RLock() @@ -214,15 +234,16 @@ func (s *Subscriber) downTracksReports() { } } s.RUnlock() - i := math.Ceil(float64(len(sd)) / float64(20)) + i := 0 j := 0 - for i > 0 { - if i > 1 { - sd = sd[j*20 : (j+1)*20-1] - } else { - sd = sd[j*20 : cap(sd)] + for i < len(sd) { + i = (j + 1) * 15 + if i >= len(sd) { + i = len(sd) } - r = append(r, &rtcp.SourceDescription{Chunks: sd}) + nsd := sd[j*15 : i] + r = append(r, &rtcp.SourceDescription{Chunks: nsd}) + j++ if err := s.pc.WriteRTCP(r); err != nil { if err == io.EOF || err == io.ErrClosedPipe { return @@ -230,8 +251,6 @@ func (s *Subscriber) downTracksReports() { log.Errorf("Sending downtrack reports err: %v", err) } r = r[:0] - i-- - j++ } } }