diff --git a/internal/mux/mux.go b/internal/mux/mux.go index 81ea005066c..5ade1cd4fcf 100644 --- a/internal/mux/mux.go +++ b/internal/mux/mux.go @@ -9,6 +9,7 @@ import ( "io" "net" "sync" + "time" "github.com/pion/ice/v3" "github.com/pion/logging" @@ -34,20 +35,25 @@ type Mux struct { bufferSize int closedCh chan struct{} + pendingPackets map[*pendingPacket]struct{} + pendingPacketsLock sync.Mutex + log logging.LeveledLogger } // NewMux creates a new Mux func NewMux(config Config) *Mux { m := &Mux{ - nextConn: config.Conn, - endpoints: make(map[*Endpoint]MatchFunc), - bufferSize: config.BufferSize, - closedCh: make(chan struct{}), - log: config.LoggerFactory.NewLogger("mux"), + nextConn: config.Conn, + endpoints: make(map[*Endpoint]MatchFunc), + bufferSize: config.BufferSize, + closedCh: make(chan struct{}), + log: config.LoggerFactory.NewLogger("mux"), + pendingPackets: make(map[*pendingPacket]struct{}), } go m.readLoop() + go m.pendingPacketsHandler() return m } @@ -140,10 +146,17 @@ func (m *Mux) dispatch(buf []byte) error { if endpoint == nil { if len(buf) > 0 { - m.log.Warnf("Warning: mux: no endpoint for packet starting with %d", buf[0]) + m.log.Warnf("Warning: mux: no endpoint for packet starting with %d, queueing packet as pending", buf[0]) } else { - m.log.Warnf("Warning: mux: no endpoint for zero length packet") + m.log.Warnf("Warning: mux: no endpoint for zero length packet, queueing packet as pending") } + + m.pendingPacketsLock.Lock() + m.pendingPackets[&pendingPacket{ + t: time.Now(), + data: buf, + }] = struct{}{} + m.pendingPacketsLock.Unlock() return nil } @@ -157,3 +170,38 @@ func (m *Mux) dispatch(buf []byte) error { return err } + +func (m *Mux) pendingPacketsHandler() { + ticker := time.NewTicker(time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + m.lock.RLock() + m.pendingPacketsLock.Lock() + for p := range m.pendingPackets { + if time.Since(p.t) > time.Second*5 { + m.log.Warnf("Warning: mux: dropping packet after 5 seconds in pending queue") + delete(m.pendingPackets, p) + } + for endpoint, f := range m.endpoints { + if f(p.data) { + _, _ = endpoint.buffer.Write(p.data) + delete(m.pendingPackets, p) + m.log.Warnf("Warning: mux: found endpoint for packet after %s in pending queue", time.Since(p.t)) + } + } + } + m.pendingPacketsLock.Unlock() + m.lock.RUnlock() + case <-m.closedCh: + return + } + } +} + +type pendingPacket struct { + t time.Time + data []byte +}