diff --git a/internal/mux/mux.go b/internal/mux/mux.go index 81ea005066c..20c399c8cb5 100644 --- a/internal/mux/mux.go +++ b/internal/mux/mux.go @@ -15,8 +15,13 @@ import ( "github.com/pion/transport/v3/packetio" ) -// The maximum amount of data that can be buffered before returning errors. -const maxBufferSize = 1000 * 1000 // 1MB +const ( + // The maximum amount of data that can be buffered before returning errors. + maxBufferSize = 1000 * 1000 // 1MB + + // How many total pending packets can be cached + maxPendingPackets = 15 +) // Config collects the arguments to mux.Mux construction into // a single structure @@ -34,6 +39,8 @@ type Mux struct { bufferSize int closedCh chan struct{} + pendingPackets [][]byte + log logging.LeveledLogger } @@ -66,6 +73,8 @@ func (m *Mux) NewEndpoint(f MatchFunc) *Endpoint { m.endpoints[e] = f m.lock.Unlock() + go m.handlePendingPackets(e, f) + return e } @@ -127,6 +136,11 @@ func (m *Mux) readLoop() { } func (m *Mux) dispatch(buf []byte) error { + if len(buf) == 0 { + m.log.Warnf("Warning: mux: unable to dispatch zero length packet") + return nil + } + var endpoint *Endpoint m.lock.Lock() @@ -139,11 +153,16 @@ func (m *Mux) dispatch(buf []byte) error { m.lock.Unlock() if endpoint == nil { - if len(buf) > 0 { - m.log.Warnf("Warning: mux: no endpoint for packet starting with %d", buf[0]) + m.lock.Lock() + defer m.lock.Unlock() + + if len(m.pendingPackets) >= maxPendingPackets { + m.log.Warnf("Warning: mux: no endpoint for packet starting with %d, not adding to queue size(%d)", buf[0], len(m.pendingPackets)) } else { - m.log.Warnf("Warning: mux: no endpoint for zero length packet") + m.log.Warnf("Warning: mux: no endpoint for packet starting with %d, adding to queue size(%d)", buf[0], len(m.pendingPackets)) + m.pendingPackets = append(m.pendingPackets, append([]byte{}, buf...)) } + return nil } @@ -157,3 +176,20 @@ func (m *Mux) dispatch(buf []byte) error { return err } + +func (m *Mux) handlePendingPackets(endpoint *Endpoint, matchFunc MatchFunc) { + m.lock.Lock() + defer m.lock.Unlock() + + pendingPackets := make([][]byte, len(m.pendingPackets)) + for _, buf := range m.pendingPackets { + if matchFunc(buf) { + if _, err := endpoint.buffer.Write(buf); err != nil { + m.log.Warnf("Warning: mux: error writing packet to endpoint from pending queue: %s", err) + } + } else { + pendingPackets = append(pendingPackets, buf) + } + } + m.pendingPackets = pendingPackets +} diff --git a/internal/mux/mux_test.go b/internal/mux/mux_test.go index 75859b20765..7adbed5c411 100644 --- a/internal/mux/mux_test.go +++ b/internal/mux/mux_test.go @@ -154,3 +154,36 @@ func BenchmarkDispatch(b *testing.B) { } } } + +func TestPendingQueue(t *testing.T) { + factory := logging.NewDefaultLoggerFactory() + factory.DefaultLogLevel = logging.LogLevelDebug + m := &Mux{ + endpoints: make(map[*Endpoint]MatchFunc), + log: factory.NewLogger("mux"), + } + + // Assert empty packets don't end up in queue + require.NoError(t, m.dispatch([]byte{})) + require.Equal(t, len(m.pendingPackets), 0) + + // Test Happy Case + inBuffer := []byte{20, 1, 2, 3, 4} + outBuffer := make([]byte, len(inBuffer)) + + require.NoError(t, m.dispatch(inBuffer)) + + endpoint := m.NewEndpoint(MatchDTLS) + require.NotNil(t, endpoint) + + _, err := endpoint.Read(outBuffer) + require.NoError(t, err) + + require.Equal(t, outBuffer, inBuffer) + + // Assert limit on pendingPackets + for i := 0; i <= 100; i++ { + require.NoError(t, m.dispatch([]byte{64, 65, 66})) + } + require.Equal(t, len(m.pendingPackets), maxPendingPackets) +}