Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Limit maximum tsn queued by the association #323

Merged
merged 4 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion association.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,17 @@
// other constants
const (
acceptChSize = 16
// avgChunkSize is an estimate of the average chunk size. There is no theory behind
// this estimate.
avgChunkSize = 500
// minTSNOffset is the minimum offset over the cummulative TSN that we will enqueue
// irrespective of the receive buffer size
// see Association.getMaxTSNOffset
minTSNOffset = 2000
// maxTSNOffset is the maximum offset over the cummulative TSN that we will enqueue
// irrespective of the receive buffer size
// see Association.getMaxTSNOffset
maxTSNOffset = 40000
)

func getAssociationStateString(a uint32) string {
Expand Down Expand Up @@ -1110,6 +1121,23 @@
return a.srtt.Load().(float64) //nolint:forcetypeassert
}

// getMaxTSNOffset returns the maximum offset over the current cummulative TSN that
// we are willing to enqueue. Limiting the maximum offset limits the number of
// tsns we have in the payloadQueue map. This ensures that we don't use too much space in
// the map itself. This also ensures that we keep the bytes utilized in the receive
// buffer within a small multiple of the user provided max receive buffer size.
func (a *Association) getMaxTSNOffset() uint32 {
// 4 is a magic number here. There is no theory behind this.
offset := (a.maxReceiveBufferSize * 4) / avgChunkSize
if offset < minTSNOffset {
offset = minTSNOffset
}
if offset > maxTSNOffset {
offset = maxTSNOffset
}

Check warning on line 1137 in association.go

View check run for this annotation

Codecov / codecov/patch

association.go#L1136-L1137

Added lines #L1136 - L1137 were not covered by tests
return offset
}

func setSupportedExtensions(init *chunkInitCommon) {
// nolint:godox
// TODO RFC5061 https://tools.ietf.org/html/rfc6525#section-5.2
Expand Down Expand Up @@ -1378,7 +1406,7 @@
a.name, d.tsn, d.immediateSack, len(d.userData))
a.stats.incDATAs()

canPush := a.payloadQueue.canPush(d, a.peerLastTSN)
canPush := a.payloadQueue.canPush(d, a.peerLastTSN, a.getMaxTSNOffset())
if canPush {
s := a.getOrCreateStream(d.streamIdentifier, true, PayloadTypeUnknown)
if s == nil {
Expand Down
216 changes: 216 additions & 0 deletions association_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"runtime"
"strings"
"sync"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -2731,6 +2732,221 @@ loop:
return a1, a2, nil
}

// udpDiscardReader blocks all reads after block is set to true.
// This allows us to send arbitrary packets on a stream and block the packets received in response
type udpDiscardReader struct {
net.Conn
ctx context.Context
block atomic.Bool
}

func (d *udpDiscardReader) Read(b []byte) (n int, err error) {
if d.block.Load() {
<-d.ctx.Done()
return 0, d.ctx.Err()
}
return d.Conn.Read(b)
}

func createAssociationPair(udpConn1 net.Conn, udpConn2 net.Conn) (*Association, *Association, error) {
loggerFactory := logging.NewDefaultLoggerFactory()

a1Chan := make(chan interface{})
a2Chan := make(chan interface{})

ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()

go func() {
a, err2 := createClientWithContext(ctx, Config{
NetConn: udpConn1,
LoggerFactory: loggerFactory,
})
if err2 != nil {
a1Chan <- err2
} else {
a1Chan <- a
}
}()

go func() {
a, err2 := createClientWithContext(ctx, Config{
NetConn: udpConn2,
LoggerFactory: loggerFactory,
MaxReceiveBufferSize: 100_000,
})
if err2 != nil {
a2Chan <- err2
} else {
a2Chan <- a
}
}()

var a1 *Association
var a2 *Association

loop:
for {
select {
case v1 := <-a1Chan:
switch v := v1.(type) {
case *Association:
a1 = v
if a2 != nil {
break loop
}
case error:
return nil, nil, v
}
case v2 := <-a2Chan:
switch v := v2.(type) {
case *Association:
a2 = v
if a1 != nil {
break loop
}
case error:
return nil, nil, v
}
}
}
return a1, a2, nil
}

func noErrorClose(t *testing.T, closeF func() error) {
t.Helper()
require.NoError(t, closeF())
}

// readMyNextTSN uses a lock to read the myNextTSN field of the association.
// Avoids a data race.
func readMyNextTSN(a *Association) uint32 {
a.lock.Lock()
defer a.lock.Unlock()
return a.myNextTSN
}

func TestAssociationReceiveWindow(t *testing.T) {
udp1, udp2 := createUDPConnPair()
ctx, cancel := context.WithCancel(context.Background())
dudp1 := &udpDiscardReader{Conn: udp1, ctx: ctx}
// a1 is the association used for sending data
// a2 is the association with receive window of 100kB which we will
// try to bypass
a1, a2, err := createAssociationPair(dudp1, udp2)
require.NoError(t, err)
defer noErrorClose(t, a2.Close)
defer noErrorClose(t, a1.Close)
s1, err := a1.OpenStream(1, PayloadTypeWebRTCBinary)
require.NoError(t, err)
defer noErrorClose(t, s1.Close)
_, err = s1.WriteSCTP([]byte("hello"), PayloadTypeWebRTCBinary)
require.NoError(t, err)
dudp1.block.Store(true)

_, err = s1.WriteSCTP([]byte("hello"), PayloadTypeWebRTCBinary)
require.NoError(t, err)
s2, err := a2.AcceptStream()
require.NoError(t, err)
require.Equal(t, uint16(1), s2.streamIdentifier)

done := make(chan bool)
go func() {
chunks := s1.packetize(make([]byte, 1000), PayloadTypeWebRTCBinary)
chunks = chunks[:1]
chunk := chunks[0]
// Fake the TSN and enqueue 1 chunk with a very high tsn in the payload queue
chunk.tsn = readMyNextTSN(a1) + 1e9
for chunk.tsn > readMyNextTSN(a1) {
select {
case <-done:
return
default:
}
chunk.tsn--
pp := a1.bundleDataChunksIntoPackets(chunks)
for _, p := range pp {
raw, err := p.marshal(true)
if err != nil {
return
}
_, err = a1.netConn.Write(raw)
if err != nil {
return
}
}
if chunk.tsn%10 == 0 {
time.Sleep(10 * time.Millisecond)
}
}
}()

for cnt := 0; cnt < 15; cnt++ {
bytesQueued := s2.getNumBytesInReassemblyQueue()
if bytesQueued > 5_000_000 {
t.Error("too many bytes enqueued with receive window of 10kb", bytesQueued)
break
}
t.Log("bytes queued", bytesQueued)
time.Sleep(1 * time.Second)
}
close(done)
cancel()
}

func TestAssociationMaxTSNOffset(t *testing.T) {
udp1, udp2 := createUDPConnPair()
// a1 is the association used for sending data
// a2 is the association with receive window of 100kB which we will
// try to bypass
a1, a2, err := createAssociationPair(udp1, udp2)
require.NoError(t, err)
defer noErrorClose(t, a2.Close)
defer noErrorClose(t, a1.Close)
s1, err := a1.OpenStream(1, PayloadTypeWebRTCBinary)
require.NoError(t, err)
defer noErrorClose(t, s1.Close)
_, err = s1.WriteSCTP([]byte("hello"), PayloadTypeWebRTCBinary)
require.NoError(t, err)
_, err = s1.WriteSCTP([]byte("hello"), PayloadTypeWebRTCBinary)
require.NoError(t, err)
s2, err := a2.AcceptStream()
require.NoError(t, err)
require.Equal(t, uint16(1), s2.streamIdentifier)

chunks := s1.packetize(make([]byte, 1000), PayloadTypeWebRTCBinary)
chunks = chunks[:1]
sendChunk := func(tsn uint32) {
chunk := chunks[0]
// Fake the TSN and enqueue 1 chunk with a very high tsn in the payload queue
chunk.tsn = tsn
pp := a1.bundleDataChunksIntoPackets(chunks)
for _, p := range pp {
raw, err := p.marshal(true)
if err != nil {
t.Fatal(err)
return
}
_, err = a1.netConn.Write(raw)
if err != nil {
t.Fatal(err)
return
}
}
}
sendChunk(readMyNextTSN(a1) + 100_000)
time.Sleep(100 * time.Millisecond)
require.Less(t, s2.getNumBytesInReassemblyQueue(), 1000)

sendChunk(readMyNextTSN(a1) + 10_000)
time.Sleep(100 * time.Millisecond)
require.Less(t, s2.getNumBytesInReassemblyQueue(), 1000)

sendChunk(readMyNextTSN(a1) + minTSNOffset - 100)
time.Sleep(100 * time.Millisecond)
require.Greater(t, s2.getNumBytesInReassemblyQueue(), 1000)
}

func TestAssociation_Shutdown(t *testing.T) {
checkGoroutineLeaks(t)

Expand Down
4 changes: 2 additions & 2 deletions payload_queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ func (q *payloadQueue) updateSortedKeys() {
})
}

func (q *payloadQueue) canPush(p *chunkPayloadData, cumulativeTSN uint32) bool {
func (q *payloadQueue) canPush(p *chunkPayloadData, cumulativeTSN uint32, maxTSNOffset uint32) bool {
_, ok := q.chunkMap[p.tsn]
if ok || sna32LTE(p.tsn, cumulativeTSN) {
if ok || sna32LTE(p.tsn, cumulativeTSN) || sna32GTE(p.tsn, cumulativeTSN+maxTSNOffset) {
return false
}
return true
Expand Down
Loading