From e65b5595f56f30acd856122ed5c381be70ac60a0 Mon Sep 17 00:00:00 2001
From: cnderrauber
Date: Mon, 18 Nov 2024 17:13:35 +0800
Subject: [PATCH] Add congestion control parameters to config

The loss-based congestion control performs poorly under high bandwidth,
high RTT, and packet loss: the congestion window collapses to 1 MTU
after a retransmission timeout and then grows back slowly, and
fast-recovery retransmission makes the sender exit fast recovery slowly
under consecutive packet loss. This change adds parameters to the
config so that users can set them to get higher throughput in such
cases.
---
 association.go      | 41 +++++++++++++++----
 association_test.go | 99 ++++++++++++++++++++++++++++++++++++++++-----
 2 files changed, 124 insertions(+), 16 deletions(-)

diff --git a/association.go b/association.go
index 29e9978c..1f344e12 100644
--- a/association.go
+++ b/association.go
@@ -212,6 +212,9 @@ type Association struct {
 	partialBytesAcked    uint32
 	inFastRecovery       bool
 	fastRecoverExitPoint uint32
+	minCwnd              uint32 // Minimum congestion window
+	fastRtxWnd           uint32 // Send window for fast retransmit
+	cwndCAStep           uint32 // Step of congestion window increase at Congestion Avoidance
 
 	// RTX & Ack timer
 	rtoMgr           *rtoManager
@@ -261,8 +264,16 @@ type Config struct {
 	MaxMessageSize       uint32
 	EnableZeroChecksum   bool
 	LoggerFactory        logging.LoggerFactory
+
+	// congestion control configuration
 	// RTOMax is the maximum retransmission timeout in milliseconds
 	RTOMax float64
+	// Minimum congestion window
+	MinCwnd uint32
+	// Send window for fast retransmit
+	FastRtxWnd uint32
+	// Step of congestion window increase at Congestion Avoidance
+	CwndCAStep uint32
 }
 
 // Server accepts a SCTP stream over a conn
@@ -325,6 +336,9 @@ func createAssociation(config Config) *Association {
 		netConn:              config.NetConn,
 		maxReceiveBufferSize: maxReceiveBufferSize,
 		maxMessageSize:       maxMessageSize,
+		minCwnd:              config.MinCwnd,
+		fastRtxWnd:           config.FastRtxWnd,
+		cwndCAStep:           config.CwndCAStep,
 
 		// These two max values have us not need to follow
 		// 5.1.1 where this peer may be incapable of supporting
@@ -803,9 +817,13 @@ func (a *Association) gatherOutboundFastRetransmissionPackets(rawPackets [][]byt
 	if a.willRetransmitFast {
 		a.willRetransmitFast = false
 
-		toFastRetrans := []chunk{}
+		toFastRetrans := []*chunkPayloadData{}
 		fastRetransSize := commonHeaderSize
 
+		fastRetransWnd := a.MTU()
+		if fastRetransWnd < a.fastRtxWnd {
+			fastRetransWnd = a.fastRtxWnd
+		}
 		for i := 0; ; i++ {
 			c, ok := a.inflightQueue.get(a.cumulativeTSNAckPoint + uint32(i) + 1)
 			if !ok {
@@ -831,7 +849,7 @@ func (a *Association) gatherOutboundFastRetransmissionPackets(rawPackets [][]byt
 			// packet.
 
 			dataChunkSize := dataChunkHeaderSize + uint32(len(c.userData))
-			if a.MTU() < fastRetransSize+dataChunkSize {
+			if fastRetransWnd < fastRetransSize+dataChunkSize {
 				break
 			}
 
@@ -845,10 +863,12 @@ func (a *Association) gatherOutboundFastRetransmissionPackets(rawPackets [][]byt
 		}
 
 		if len(toFastRetrans) > 0 {
-			raw, err := a.marshalPacket(a.createPacket(toFastRetrans))
-			if err != nil {
-				a.log.Warnf("[%s] failed to serialize a DATA packet to be fast-retransmitted", a.name)
-			} else {
+			for _, p := range a.bundleDataChunksIntoPackets(toFastRetrans) {
+				raw, err := a.marshalPacket(p)
+				if err != nil {
+					a.log.Warnf("[%s] failed to serialize a DATA packet to be fast-retransmitted", a.name)
+					continue
+				}
 				rawPackets = append(rawPackets, raw)
 			}
 		}
@@ -1115,6 +1135,9 @@ func (a *Association) CWND() uint32 {
 }
 
 func (a *Association) setCWND(cwnd uint32) {
+	if cwnd < a.minCwnd {
+		cwnd = a.minCwnd
+	}
 	atomic.StoreUint32(&a.cwnd, cwnd)
 }
 
@@ -1720,7 +1743,11 @@ func (a *Association) onCumulativeTSNAckPointAdvanced(totalBytesAcked int) {
 		// reset partial_bytes_acked to (partial_bytes_acked - cwnd).
 		if a.partialBytesAcked >= a.CWND() && a.pendingQueue.size() > 0 {
 			a.partialBytesAcked -= a.CWND()
-			a.setCWND(a.CWND() + a.MTU())
+			step := a.MTU()
+			if step < a.cwndCAStep {
+				step = a.cwndCAStep
+			}
+			a.setCWND(a.CWND() + step)
 			a.log.Tracef("[%s] updated cwnd=%d ssthresh=%d acked=%d (CA)",
 				a.name, a.CWND(), a.ssthresh, totalBytesAcked)
 		}
diff --git a/association_test.go b/association_test.go
index a8611318..860741b3 100644
--- a/association_test.go
+++ b/association_test.go
@@ -2735,6 +2735,10 @@ func (d *udpDiscardReader) Read(b []byte) (n int, err error) {
 }
 
 func createAssociationPair(udpConn1 net.Conn, udpConn2 net.Conn) (*Association, *Association, error) {
+	return createAssociationPairWithConfig(udpConn1, udpConn2, Config{})
+}
+
+func createAssociationPairWithConfig(udpConn1 net.Conn, udpConn2 net.Conn, config Config) (*Association, *Association, error) {
 	loggerFactory := logging.NewDefaultLoggerFactory()
 
 	a1Chan := make(chan interface{})
@@ -2744,10 +2748,10 @@ func createAssociationPair(udpConn1 net.Conn, udpConn2 net.Conn) (*Association,
 	defer cancel()
 
 	go func() {
-		a, err2 := createClientWithContext(ctx, Config{
-			NetConn:       udpConn1,
-			LoggerFactory: loggerFactory,
-		})
+		cfg := config
+		cfg.NetConn = udpConn1
+		cfg.LoggerFactory = loggerFactory
+		a, err2 := createClientWithContext(ctx, cfg)
 		if err2 != nil {
 			a1Chan <- err2
 		} else {
@@ -2756,11 +2760,13 @@ func createAssociationPair(udpConn1 net.Conn, udpConn2 net.Conn) (*Association,
 	}()
 
 	go func() {
-		a, err2 := createClientWithContext(ctx, Config{
-			NetConn:              udpConn2,
-			LoggerFactory:        loggerFactory,
-			MaxReceiveBufferSize: 100_000,
-		})
+		cfg := config
+		cfg.NetConn = udpConn2
+		cfg.LoggerFactory = loggerFactory
+		if cfg.MaxReceiveBufferSize == 0 {
+			cfg.MaxReceiveBufferSize = 100_000
+		}
+		a, err2 := createClientWithContext(ctx, cfg)
 		if err2 != nil {
 			a2Chan <- err2
 		} else {
@@ -2880,6 +2886,81 @@ func TestAssociationReceiveWindow(t *testing.T) {
 	cancel()
 }
 
+func TestAssociation_CongestionParameters(t *testing.T) {
+	udp1, udp2 := createUDPConnPair()
+	a1, a2, err := createAssociationPairWithConfig(udp1, udp2, Config{MinCwnd: 14000, FastRtxWnd: 14000, CwndCAStep: 14000, MaxReceiveBufferSize: 1500})
+	require.NoError(t, err)
+	defer noErrorClose(t, a2.Close)
+	defer noErrorClose(t, a1.Close)
+	s1, err := a1.OpenStream(1, PayloadTypeWebRTCBinary)
+	require.NoError(t, err)
+	defer noErrorClose(t, s1.Close)
+	_, err = s1.WriteSCTP([]byte("hello"), PayloadTypeWebRTCBinary)
+	require.NoError(t, err)
+	s2, err := a2.AcceptStream()
+	require.NoError(t, err)
+	s2.Close()
+
+	a1.rtoMgr.setRTO(1000, true)
+	// ack the hello packet
+	time.Sleep(1 * time.Second)
+
+	require.Equal(t, uint32(a1.minCwnd), a1.CWND())
+
+	var shouldDrop atomic.Bool
+	var dropCounter atomic.Uint32
+	dbConn1 := udp1.(*dumbConn2)
+	dbConn1.remoteInboundHandler = func(packet []byte) {
+		if !shouldDrop.Load() {
+			udp2.(*dumbConn2).inboundHandler(packet)
+		} else {
+			dropCounter.Add(1)
+		}
+	}
+
+	shouldDrop.Store(true)
+	// send packets and dropped
+	buf := make([]byte, 1000)
+	for i := 0; i < 10; i++ {
+		_, err = s1.WriteSCTP(buf, PayloadTypeWebRTCBinary)
+		require.NoError(t, err)
+	}
+
+	require.Eventually(t, func() bool { return dropCounter.Load() >= 10 }, 5*time.Second, 10*time.Millisecond)
+	// send packets to trigger fast retransmit
+	shouldDrop.Store(false)
+
+	require.Zero(t, a1.stats.getNumFastRetrans())
+	require.False(t, a1.inFastRecovery)
+
+	// wait SACK
+	sackCh := make(chan []byte, 1)
+	udp2.(*dumbConn2).remoteInboundHandler = func(buf []byte) {
+		p := &packet{}
+		require.NoError(t, p.unmarshal(true, buf))
+		for _, c := range p.chunks {
+			if _, ok := c.(*chunkSelectiveAck); ok {
+				select {
+				case sackCh <- buf:
+				default:
+				}
+				return
+			}
+		}
+	}
+	// wait sack to trigger fast retransmit
+	for i := 0; i < 3; i++ {
+		_, err = s1.WriteSCTP(buf, PayloadTypeWebRTCBinary)
+		require.NoError(t, err)
+		udp1.(*dumbConn2).inboundHandler(<-sackCh)
+	}
+	require.Greater(t, a1.minCwnd+a1.cwndCAStep, a1.CWND())
+
+	// fast retransmit and new sack sent
+	require.Eventually(t, func() bool { return a1.inFastRecovery }, 5*time.Second, 10*time.Millisecond)
+	require.GreaterOrEqual(t, uint64(10), a1.stats.getNumFastRetrans())
+}
+
 func TestAssociationMaxTSNOffset(t *testing.T) {
 	udp1, udp2 := createUDPConnPair()
 	// a1 is the association used for sending data
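
Usage sketch (illustrative, not part of the patch): the new knobs are
plain Config fields, so a dialer can set them directly. The address,
the 1500-byte MTU multiples, and the specific values below are
assumptions for illustration, not recommendations from this change;
leaving the three fields at zero keeps the current behavior.

	package main

	import (
		"log"
		"net"

		"github.com/pion/logging"
		"github.com/pion/sctp"
	)

	func main() {
		// A connected UDP socket to the remote SCTP peer; the address is a placeholder.
		conn, err := net.Dial("udp", "192.0.2.1:5000")
		if err != nil {
			log.Fatal(err)
		}

		assoc, err := sctp.Client(sctp.Config{
			NetConn:       conn,
			LoggerFactory: logging.NewDefaultLoggerFactory(),
			MinCwnd:       10 * 1500, // never let cwnd shrink below ~10 MTUs
			FastRtxWnd:    4 * 1500,  // allow up to ~4 MTUs of data per fast-retransmit round
			CwndCAStep:    2 * 1500,  // grow cwnd by ~2 MTUs per step in congestion avoidance
		})
		if err != nil {
			log.Fatal(err)
		}
		defer assoc.Close()

		// ... open streams and transfer data as usual ...
	}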