From d64187f4263e271957d2e22dce3c245ccbea7bd0 Mon Sep 17 00:00:00 2001
From: sukun <sukunrt@gmail.com>
Date: Fri, 29 Mar 2024 02:05:50 +0530
Subject: [PATCH] make the maxTSNOffset a function of the receive window

---
 association.go      |  36 +++++++++++---
 association_test.go | 116 ++++++++++++++++++++++++++++++++++++++++++++
 payload_queue.go    |   2 +-
 3 files changed, 146 insertions(+), 8 deletions(-)

diff --git a/association.go b/association.go
index 75b26397..a5310d6f 100644
--- a/association.go
+++ b/association.go
@@ -99,12 +99,17 @@ const (
 // other constants
 const (
 	acceptChSize = 16
-	// maxTSNOffset is the maximum offset of a received chunk TSN from the cummulative TSN
-	// we have seen so far that we will enqueue.
-	// For a chunk to be enqueued chunk.tsn < cummulativeTSN + maxTSNOffset
-	// This allows us to not enqueue too many bytes over the receive window in case of out
-	// of order delivery. A buffer of 1000 TSNs implies an excess of roughly 2MB.
-	maxTSNOffset = 2000
+	// avgChunkSize is an estimate of the average chunk size. There is no theory behind
+	// this estimate.
+	avgChunkSize = 500
+	// minTSNOffset is the minimum offset over the cummulative TSN that we will enqueue
+	// irrespective of the receive buffer size
+	// see Association.getMaxTSNOffset
+	minTSNOffset = 2000
+	// maxTSNOffset is the maximum offset over the cummulative TSN that we will enqueue
+	// irrespective of the receive buffer size
+	// see Association.getMaxTSNOffset
+	maxTSNOffset = 40000
 )
 
 func getAssociationStateString(a uint32) string {
@@ -1116,6 +1121,23 @@ func (a *Association) SRTT() float64 {
 	return a.srtt.Load().(float64) //nolint:forcetypeassert
 }
 
+// getMaxTSNOffset returns the maximum offset over the current cummulative TSN that
+// we are willing to enqueue. Limiting the maximum offset limits the number of
+// tsns we have in the payloadQueue map. This ensures that we don't use too much space in
+// the map itself. This also ensures that we keep the bytes utilised in the receive
+// buffer within a small multiple of the user provided max receive buffer size.
+func (a *Association) getMaxTSNOffset() uint32 {
+	// 4 is a magic number here. There is no theory behind this.
+	offset := (a.maxReceiveBufferSize * 4) / avgChunkSize
+	if offset < minTSNOffset {
+		offset = minTSNOffset
+	}
+	if offset > maxTSNOffset {
+		offset = maxTSNOffset
+	}
+	return offset
+}
+
 func setSupportedExtensions(init *chunkInitCommon) {
 	// nolint:godox
 	// TODO RFC5061 https://tools.ietf.org/html/rfc6525#section-5.2
@@ -1384,7 +1406,7 @@ func (a *Association) handleData(d *chunkPayloadData) []*packet {
 		a.name, d.tsn, d.immediateSack, len(d.userData))
 	a.stats.incDATAs()
 
-	canPush := a.payloadQueue.canPush(d, a.peerLastTSN)
+	canPush := a.payloadQueue.canPush(d, a.peerLastTSN, a.getMaxTSNOffset())
 	if canPush {
 		s := a.getOrCreateStream(d.streamIdentifier, true, PayloadTypeUnknown)
 		if s == nil {
diff --git a/association_test.go b/association_test.go
index 2c8681e1..89c25e24 100644
--- a/association_test.go
+++ b/association_test.go
@@ -2879,6 +2879,122 @@ func TestAssociationReceiveWindow(t *testing.T) {
 	cancel()
 }
 
+func TestAssociationMaxTSNOffset(t *testing.T) {
+	udp1, udp2 := createUDPConnPair()
+	createAssociations := func() (*Association, *Association, error) {
+		loggerFactory := logging.NewDefaultLoggerFactory()
+
+		a1Chan := make(chan interface{})
+		a2Chan := make(chan interface{})
+
+		ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+		defer cancel()
+
+		go func() {
+			a, err2 := createClientWithContext(ctx, Config{
+				NetConn:       udp1,
+				LoggerFactory: loggerFactory,
+			})
+			if err2 != nil {
+				a1Chan <- err2
+			} else {
+				a1Chan <- a
+			}
+		}()
+
+		go func() {
+			a, err2 := createClientWithContext(ctx, Config{
+				NetConn:              udp2,
+				LoggerFactory:        loggerFactory,
+				MaxReceiveBufferSize: 100_000,
+			})
+			if err2 != nil {
+				a2Chan <- err2
+			} else {
+				a2Chan <- a
+			}
+		}()
+
+		var a1 *Association
+		var a2 *Association
+
+	loop:
+		for {
+			select {
+			case v1 := <-a1Chan:
+				switch v := v1.(type) {
+				case *Association:
+					a1 = v
+					if a2 != nil {
+						break loop
+					}
+				case error:
+					return nil, nil, v
+				}
+			case v2 := <-a2Chan:
+				switch v := v2.(type) {
+				case *Association:
+					a2 = v
+					if a1 != nil {
+						break loop
+					}
+				case error:
+					return nil, nil, v
+				}
+			}
+		}
+		return a1, a2, nil
+	}
+	// a1 is the association used for sending data
+	// a2 is the association with receive window of 100kB which we will
+	// try to bypass
+	a1, a2, err := createAssociations()
+
+	require.NoError(t, err)
+	defer a2.Close()
+	defer a1.Close()
+	s1, err := a1.OpenStream(1, PayloadTypeWebRTCBinary)
+	require.NoError(t, err)
+	defer s1.Close()
+	s1.WriteSCTP([]byte("hello"), PayloadTypeWebRTCBinary)
+	s1.WriteSCTP([]byte("hello"), PayloadTypeWebRTCBinary)
+	s2, err := a2.AcceptStream()
+	require.NoError(t, err)
+	require.Equal(t, uint16(1), s2.streamIdentifier)
+
+	chunks := s1.packetize(make([]byte, 1000), PayloadTypeWebRTCBinary)
+	chunks = chunks[:1]
+	sendChunk := func(tsn uint32) {
+		chunk := chunks[0]
+		// Fake the TSN and enqueue 1 chunk with a very high tsn in the payload queue
+		chunk.tsn = tsn
+		pp := a1.bundleDataChunksIntoPackets(chunks)
+		for _, p := range pp {
+			raw, err := p.marshal(true)
+			if err != nil {
+				t.Fatal(err)
+				return
+			}
+			_, err = a1.netConn.Write(raw)
+			if err != nil {
+				t.Fatal(err)
+				return
+			}
+		}
+	}
+	sendChunk(a1.myNextTSN + 100_000)
+	time.Sleep(100 * time.Millisecond)
+	require.Less(t, s2.getNumBytesInReassemblyQueue(), 1000)
+
+	sendChunk(a1.myNextTSN + 10_000)
+	time.Sleep(100 * time.Millisecond)
+	require.Less(t, s2.getNumBytesInReassemblyQueue(), 1000)
+
+	sendChunk(a1.myNextTSN + minTSNOffset - 100)
+	time.Sleep(100 * time.Millisecond)
+	require.Greater(t, s2.getNumBytesInReassemblyQueue(), 1000)
+}
+
 func TestAssociation_Shutdown(t *testing.T) {
 	checkGoroutineLeaks(t)
 
diff --git a/payload_queue.go b/payload_queue.go
index 1510bb6f..a0b1b26f 100644
--- a/payload_queue.go
+++ b/payload_queue.go
@@ -36,7 +36,7 @@ func (q *payloadQueue) updateSortedKeys() {
 	})
 }
 
-func (q *payloadQueue) canPush(p *chunkPayloadData, cumulativeTSN uint32) bool {
+func (q *payloadQueue) canPush(p *chunkPayloadData, cumulativeTSN uint32, maxTSNOffset uint32) bool {
 	_, ok := q.chunkMap[p.tsn]
 	if ok || sna32LTE(p.tsn, cumulativeTSN) || sna32GTE(p.tsn, cumulativeTSN+maxTSNOffset) {
 		return false