Skip to content

Commit

Permalink
Merge branch 'master' into fastrecover_missing
Browse files Browse the repository at this point in the history
  • Loading branch information
cnderrauber authored Nov 20, 2024
2 parents 5f604a4 + 7d6927e commit af751dd
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 16 deletions.
41 changes: 34 additions & 7 deletions association.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,9 @@ type Association struct {
partialBytesAcked uint32
inFastRecovery bool
fastRecoverExitPoint uint32
minCwnd uint32 // Minimum congestion window
fastRtxWnd uint32 // Send window for fast retransmit
cwndCAStep uint32 // Step of congestion window increase at Congestion Avoidance

// RTX & Ack timer
rtoMgr *rtoManager
Expand Down Expand Up @@ -261,8 +264,16 @@ type Config struct {
MaxMessageSize uint32
EnableZeroChecksum bool
LoggerFactory logging.LoggerFactory

// congestion control configuration
// RTOMax is the maximum retransmission timeout in milliseconds
RTOMax float64
// Minimum congestion window
MinCwnd uint32
// Send window for fast retransmit
FastRtxWnd uint32
// Step of congestion window increase at Congestion Avoidance
CwndCAStep uint32
}

// Server accepts a SCTP stream over a conn
Expand Down Expand Up @@ -325,6 +336,9 @@ func createAssociation(config Config) *Association {
netConn: config.NetConn,
maxReceiveBufferSize: maxReceiveBufferSize,
maxMessageSize: maxMessageSize,
minCwnd: config.MinCwnd,
fastRtxWnd: config.FastRtxWnd,
cwndCAStep: config.CwndCAStep,

// These two max values have us not need to follow
// 5.1.1 where this peer may be incapable of supporting
Expand Down Expand Up @@ -803,9 +817,13 @@ func (a *Association) gatherOutboundFastRetransmissionPackets(rawPackets [][]byt
if a.willRetransmitFast {
a.willRetransmitFast = false

toFastRetrans := []chunk{}
toFastRetrans := []*chunkPayloadData{}
fastRetransSize := commonHeaderSize

fastRetransWnd := a.MTU()
if fastRetransWnd < a.fastRtxWnd {
fastRetransWnd = a.fastRtxWnd
}
for i := 0; ; i++ {
c, ok := a.inflightQueue.get(a.cumulativeTSNAckPoint + uint32(i) + 1)
if !ok {
Expand All @@ -831,7 +849,7 @@ func (a *Association) gatherOutboundFastRetransmissionPackets(rawPackets [][]byt
// packet.

dataChunkSize := dataChunkHeaderSize + uint32(len(c.userData))
if a.MTU() < fastRetransSize+dataChunkSize {
if fastRetransWnd < fastRetransSize+dataChunkSize {
break
}

Expand All @@ -845,10 +863,12 @@ func (a *Association) gatherOutboundFastRetransmissionPackets(rawPackets [][]byt
}

if len(toFastRetrans) > 0 {
raw, err := a.marshalPacket(a.createPacket(toFastRetrans))
if err != nil {
a.log.Warnf("[%s] failed to serialize a DATA packet to be fast-retransmitted", a.name)
} else {
for _, p := range a.bundleDataChunksIntoPackets(toFastRetrans) {
raw, err := a.marshalPacket(p)
if err != nil {
a.log.Warnf("[%s] failed to serialize a DATA packet to be fast-retransmitted", a.name)
continue
}
rawPackets = append(rawPackets, raw)
}
}
Expand Down Expand Up @@ -1115,6 +1135,9 @@ func (a *Association) CWND() uint32 {
}

func (a *Association) setCWND(cwnd uint32) {
if cwnd < a.minCwnd {
cwnd = a.minCwnd
}
atomic.StoreUint32(&a.cwnd, cwnd)
}

Expand Down Expand Up @@ -1720,7 +1743,11 @@ func (a *Association) onCumulativeTSNAckPointAdvanced(totalBytesAcked int) {
// reset partial_bytes_acked to (partial_bytes_acked - cwnd).
if a.partialBytesAcked >= a.CWND() && a.pendingQueue.size() > 0 {
a.partialBytesAcked -= a.CWND()
a.setCWND(a.CWND() + a.MTU())
step := a.MTU()
if step < a.cwndCAStep {
step = a.cwndCAStep
}
a.setCWND(a.CWND() + step)
a.log.Tracef("[%s] updated cwnd=%d ssthresh=%d acked=%d (CA)",
a.name, a.CWND(), a.ssthresh, totalBytesAcked)
}
Expand Down
104 changes: 95 additions & 9 deletions association_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1839,6 +1839,7 @@ func TestAssocCongestionControl(t *testing.T) {
br := test.NewBridge()

a0, a1, err := createNewAssociationPair(br, ackModeNormal, maxReceiveBufferSize)
a0.cwndCAStep = 2800 // 2 mtu
if !assert.Nil(t, err, "failed to create associations") {
assert.FailNow(t, "failed due to earlier error")
}
Expand Down Expand Up @@ -2735,6 +2736,10 @@ func (d *udpDiscardReader) Read(b []byte) (n int, err error) {
}

func createAssociationPair(udpConn1 net.Conn, udpConn2 net.Conn) (*Association, *Association, error) {
return createAssociationPairWithConfig(udpConn1, udpConn2, Config{})
}

func createAssociationPairWithConfig(udpConn1 net.Conn, udpConn2 net.Conn, config Config) (*Association, *Association, error) {
loggerFactory := logging.NewDefaultLoggerFactory()

a1Chan := make(chan interface{})
Expand All @@ -2744,10 +2749,10 @@ func createAssociationPair(udpConn1 net.Conn, udpConn2 net.Conn) (*Association,
defer cancel()

go func() {
a, err2 := createClientWithContext(ctx, Config{
NetConn: udpConn1,
LoggerFactory: loggerFactory,
})
cfg := config
cfg.NetConn = udpConn1
cfg.LoggerFactory = loggerFactory
a, err2 := createClientWithContext(ctx, cfg)
if err2 != nil {
a1Chan <- err2
} else {
Expand All @@ -2756,11 +2761,13 @@ func createAssociationPair(udpConn1 net.Conn, udpConn2 net.Conn) (*Association,
}()

go func() {
a, err2 := createClientWithContext(ctx, Config{
NetConn: udpConn2,
LoggerFactory: loggerFactory,
MaxReceiveBufferSize: 100_000,
})
cfg := config
cfg.NetConn = udpConn2
cfg.LoggerFactory = loggerFactory
if cfg.MaxReceiveBufferSize == 0 {
cfg.MaxReceiveBufferSize = 100_000
}
a, err2 := createClientWithContext(ctx, cfg)
if err2 != nil {
a2Chan <- err2
} else {
Expand Down Expand Up @@ -2880,6 +2887,85 @@ func TestAssociationReceiveWindow(t *testing.T) {
cancel()
}

func TestAssociationFastRtxWnd(t *testing.T) {
udp1, udp2 := createUDPConnPair()
a1, a2, err := createAssociationPairWithConfig(udp1, udp2, Config{MinCwnd: 14000, FastRtxWnd: 14000})
require.NoError(t, err)
defer noErrorClose(t, a2.Close)
defer noErrorClose(t, a1.Close)
s1, err := a1.OpenStream(1, PayloadTypeWebRTCBinary)
require.NoError(t, err)
defer noErrorClose(t, s1.Close)
_, err = s1.WriteSCTP([]byte("hello"), PayloadTypeWebRTCBinary)
require.NoError(t, err)
_, err = a2.AcceptStream()
require.NoError(t, err)

a1.rtoMgr.setRTO(1000, true)
// ack the hello packet
time.Sleep(1 * time.Second)

require.Equal(t, a1.minCwnd, a1.CWND())

var shouldDrop atomic.Bool
var dropCounter atomic.Uint32
dbConn1, ok := udp1.(*dumbConn2)
require.True(t, ok)
dbConn2, ok := udp2.(*dumbConn2)
require.True(t, ok)
dbConn1.remoteInboundHandler = func(packet []byte) {
if !shouldDrop.Load() {
dbConn2.inboundHandler(packet)
} else {
dropCounter.Add(1)
}
}

shouldDrop.Store(true)
// send packets and dropped
buf := make([]byte, 1000)
for i := 0; i < 10; i++ {
_, err = s1.WriteSCTP(buf, PayloadTypeWebRTCBinary)
require.NoError(t, err)
}

require.Eventually(t, func() bool { return dropCounter.Load() >= 10 }, 5*time.Second, 10*time.Millisecond, "drop %d", dropCounter.Load())
// send packets to trigger fast retransmit
shouldDrop.Store(false)

require.Zero(t, a1.stats.getNumFastRetrans())
require.False(t, a1.inFastRecovery)

// wait SACK
sackCh := make(chan []byte, 1)
dbConn2.remoteInboundHandler = func(buf []byte) {
p := &packet{}
require.NoError(t, p.unmarshal(true, buf))
for _, c := range p.chunks {
if _, ok := c.(*chunkSelectiveAck); ok {
select {
case sackCh <- buf:
default:
}
return
}
}
}
// wait sack to trigger fast retransmit
for i := 0; i < 3; i++ {
_, err = s1.WriteSCTP(buf, PayloadTypeWebRTCBinary)
require.NoError(t, err)
dbConn1.inboundHandler(<-sackCh)
}
// fast retransmit and new sack sent
require.Eventually(t, func() bool {
a1.lock.RLock()
defer a1.lock.RUnlock()
return a1.inFastRecovery
}, 5*time.Second, 10*time.Millisecond)
require.GreaterOrEqual(t, uint64(10), a1.stats.getNumFastRetrans())
}

func TestAssociationMaxTSNOffset(t *testing.T) {
udp1, udp2 := createUDPConnPair()
// a1 is the association used for sending data
Expand Down

0 comments on commit af751dd

Please sign in to comment.