Skip to content

Commit

Permalink
Adding pendingQueue to internal/mux
Browse files Browse the repository at this point in the history
Buffer a small amount of packets in the internal/mux to allow remotes to
send DTLS traffic before ICE has completed
  • Loading branch information
lactyy authored and Sean-Der committed Aug 1, 2024
1 parent cbbb1c2 commit 5aa3ff6
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 27 deletions.
86 changes: 65 additions & 21 deletions internal/mux/mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,27 @@ import (
"io"
"net"
"sync"
"sync/atomic"
"time"

"github.com/pion/ice/v3"
"github.com/pion/logging"
"github.com/pion/transport/v3/packetio"
)

// The maximum amount of data that can be buffered before returning errors.
const maxBufferSize = 1000 * 1000 // 1MB
const (
// The maximum amount of data that can be buffered before returning errors.
maxBufferSize = 1000 * 1000 // 1MB

// How many total pending packets can be cached
maxPendingPackets = 15

// How long to attempt to deliver a packet that doesn't match any Endpoints
maxPendingPacketTime = time.Second * 5

// How often to query for endpoints in pending mux dispatch
pendingPacketEndpointQueryInterval = time.Millisecond * 20
)

// Config collects the arguments to mux.Mux construction into
// a single structure
Expand All @@ -28,11 +41,12 @@ type Config struct {

// Mux allows multiplexing
type Mux struct {
lock sync.RWMutex
nextConn net.Conn
endpoints map[*Endpoint]MatchFunc
bufferSize int
closedCh chan struct{}
lock sync.RWMutex
nextConn net.Conn
endpoints map[*Endpoint]MatchFunc
bufferSize int
closedCh chan struct{}
pendingPacketsDispatched uint64

log logging.LeveledLogger
}
Expand Down Expand Up @@ -127,33 +141,63 @@ func (m *Mux) readLoop() {
}

func (m *Mux) dispatch(buf []byte) error {
var endpoint *Endpoint

m.lock.Lock()
for e, f := range m.endpoints {
if f(buf) {
endpoint = e
break
}
if len(buf) == 0 {
m.log.Warnf("Warning: mux: unable to dispatch zero length packet")
return nil

Check warning on line 146 in internal/mux/mux.go

View check run for this annotation

Codecov / codecov/patch

internal/mux/mux.go#L145-L146

Added lines #L145 - L146 were not covered by tests
}
m.lock.Unlock()

endpoint := m.findEndpoint(buf)
if endpoint == nil {
if len(buf) > 0 {
m.log.Warnf("Warning: mux: no endpoint for packet starting with %d", buf[0])
if totalPacketsDispatched := atomic.AddUint64(&m.pendingPacketsDispatched, 1); totalPacketsDispatched >= maxPendingPackets {
m.log.Warnf("Warning: mux: no endpoint for packet starting with %d, not adding to queue size(%d)", buf[0], totalPacketsDispatched)

Check warning on line 152 in internal/mux/mux.go

View check run for this annotation

Codecov / codecov/patch

internal/mux/mux.go#L152

Added line #L152 was not covered by tests
} else {
m.log.Warnf("Warning: mux: no endpoint for zero length packet")
m.log.Warnf("Warning: mux: no endpoint for packet starting with %d, adding to queue size(%d)", buf[0], totalPacketsDispatched)
go m.queuePendingPacket(buf)
}
return nil
}

_, err := endpoint.buffer.Write(buf)

// Expected when bytes are received faster than the endpoint can process them (#2152, #2180)
_, err := endpoint.buffer.Write(buf)
if errors.Is(err, packetio.ErrFull) {
m.log.Infof("mux: endpoint buffer is full, dropping packet")
return nil
}

return err
}

func (m *Mux) findEndpoint(buf []byte) *Endpoint {
m.lock.Lock()
defer m.lock.Unlock()

for e, f := range m.endpoints {
if f(buf) {
return e
}
}

return nil
}

func (m *Mux) queuePendingPacket(buf []byte) {
t, deadline := time.Now(), time.After(maxPendingPacketTime)
ticker := time.NewTicker(pendingPacketEndpointQueryInterval)
defer ticker.Stop()

for {
select {
case <-ticker.C:
if endpoint := m.findEndpoint(buf); endpoint != nil {
if _, err := endpoint.buffer.Write(buf); err != nil {
m.log.Warnf("Warning: mux: error writing packet to endpoint after %s in pending queue: %s", time.Since(t), err)

Check warning on line 193 in internal/mux/mux.go

View check run for this annotation

Codecov / codecov/patch

internal/mux/mux.go#L193

Added line #L193 was not covered by tests
}
}
case <-deadline:
m.log.Warnf("Warning: mux: dropping packet after 5 seconds in pending queue")
return

Check warning on line 198 in internal/mux/mux.go

View check run for this annotation

Codecov / codecov/patch

internal/mux/mux.go#L196-L198

Added lines #L196 - L198 were not covered by tests
case <-m.closedCh:
return
}
}
}
28 changes: 27 additions & 1 deletion internal/mux/mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ import (
"github.com/stretchr/testify/require"
)

func endpointMatchAll([]byte) bool {
return true
}

const testPipeBufferSize = 8192

func TestNoEndpoints(t *testing.T) {
Expand Down Expand Up @@ -87,7 +91,7 @@ func TestNonFatalRead(t *testing.T) {
LoggerFactory: logging.NewDefaultLoggerFactory(),
})

e := m.NewEndpoint(MatchAll)
e := m.NewEndpoint(endpointMatchAll)

buff := make([]byte, testPipeBufferSize)
n, err := e.Read(buff)
Expand Down Expand Up @@ -154,3 +158,25 @@ func BenchmarkDispatch(b *testing.B) {
}
}
}

func TestPendingQueue(t *testing.T) {
factory := logging.NewDefaultLoggerFactory()
factory.DefaultLogLevel = logging.LogLevelDebug
m := &Mux{
endpoints: make(map[*Endpoint]MatchFunc),
log: factory.NewLogger("mux"),
}

inBuffer := []byte{20, 1, 2, 3, 4}
outBuffer := make([]byte, len(inBuffer))

require.NoError(t, m.dispatch(inBuffer))

endpoint := m.NewEndpoint(endpointMatchAll)
require.NotNil(t, endpoint)

_, err := endpoint.Read(outBuffer)
require.NoError(t, err)

require.Equal(t, outBuffer, inBuffer)
}
5 changes: 0 additions & 5 deletions internal/mux/muxfunc.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,6 @@ package mux
// MatchFunc allows custom logic for mapping packets to an Endpoint
type MatchFunc func([]byte) bool

// MatchAll always returns true
func MatchAll([]byte) bool {
return true
}

// MatchRange returns true if the first byte of buf is in [lower..upper]
func MatchRange(lower, upper byte, buf []byte) bool {
if len(buf) < 1 {
Expand Down

0 comments on commit 5aa3ff6

Please sign in to comment.