Skip to content

Commit

Permalink
Fix races from 3fba823
Browse files Browse the repository at this point in the history
  • Loading branch information
Sean-Der committed Aug 21, 2023
1 parent d5f5196 commit 9847722
Showing 1 changed file with 38 additions and 30 deletions.
68 changes: 38 additions & 30 deletions association.go
Original file line number Diff line number Diff line change
Expand Up @@ -323,9 +323,9 @@ func createAssociation(config Config) *Association {
// o The initial cwnd before DATA transmission or after a sufficiently
// long idle period MUST be set to min(4*MTU, max (2*MTU, 4380
// bytes)).
a.cwnd = min32(4*a.mtu, max32(2*a.mtu, 4380))
a.setCWND(min32(4*a.MTU(), max32(2*a.MTU(), 4380)))
a.log.Tracef("[%s] updated cwnd=%d ssthresh=%d inflight=%d (INI)",
a.name, a.cwnd, a.ssthresh, a.inflightQueue.getNumBytes())
a.name, a.CWND(), a.ssthresh, a.inflightQueue.getNumBytes())

a.srtt.Store(float64(0))
a.t1Init = newRTXTimer(timerT1Init, a, maxInitRetrans)
Expand Down Expand Up @@ -747,7 +747,7 @@ func (a *Association) gatherOutboundFastRetransmissionPackets(rawPackets [][]byt
// packet.

dataChunkSize := dataChunkHeaderSize + uint32(len(c.userData))
if a.mtu < fastRetransSize+dataChunkSize {
if a.MTU() < fastRetransSize+dataChunkSize {
break
}

Expand Down Expand Up @@ -1024,11 +1024,19 @@ func (a *Association) CWND() uint32 {
return atomic.LoadUint32(&a.cwnd)
}

func (a *Association) setCWND(cwnd uint32) {
atomic.StoreUint32(&a.cwnd, cwnd)
}

// RWND returns the association's current receiver window (rwnd)
func (a *Association) RWND() uint32 {
return atomic.LoadUint32(&a.rwnd)
}

func (a *Association) setRWND(rwnd uint32) {
atomic.StoreUint32(&a.rwnd, rwnd)
}

// SRTT returns the latest smoothed round-trip time (srrt)
func (a *Association) SRTT() float64 {
return a.srtt.Load().(float64) //nolint:forcetypeassert
Expand Down Expand Up @@ -1144,16 +1152,16 @@ func (a *Association) handleInitAck(p *packet, i *chunkInitAck) error {
return nil
}

a.rwnd = i.advertisedReceiverWindowCredit
a.log.Debugf("[%s] initial rwnd=%d", a.name, a.rwnd)
a.setRWND(i.advertisedReceiverWindowCredit)
a.log.Debugf("[%s] initial rwnd=%d", a.name, a.RWND())

// RFC 4690 Sec 7.2.1
// o The initial value of ssthresh MAY be arbitrarily high (for
// example, implementations MAY use the size of the receiver
// advertised window).
a.ssthresh = a.rwnd
a.ssthresh = a.RWND()
a.log.Tracef("[%s] updated cwnd=%d ssthresh=%d inflight=%d (INI)",
a.name, a.cwnd, a.ssthresh, a.inflightQueue.getNumBytes())
a.name, a.CWND(), a.ssthresh, a.inflightQueue.getNumBytes())

a.t1Init.stop()
a.storedInit = nil
Expand Down Expand Up @@ -1542,7 +1550,7 @@ func (a *Association) onCumulativeTSNAckPointAdvanced(totalBytesAcked int) {
}

// Update congestion control parameters
if a.cwnd <= a.ssthresh {
if a.CWND() <= a.ssthresh {
// RFC 4096, sec 7.2.1. Slow-Start
// o When cwnd is less than or equal to ssthresh, an SCTP endpoint MUST
// use the slow-start algorithm to increase cwnd only if the current
Expand All @@ -1556,13 +1564,13 @@ func (a *Association) onCumulativeTSNAckPointAdvanced(totalBytesAcked int) {
// path MTU.
if !a.inFastRecovery &&
a.pendingQueue.size() > 0 {
a.cwnd += min32(uint32(totalBytesAcked), a.cwnd) // TCP way
// a.cwnd += min32(uint32(totalBytesAcked), a.mtu) // SCTP way (slow)
a.setCWND(a.CWND() + min32(uint32(totalBytesAcked), a.CWND()))
// a.cwnd += min32(uint32(totalBytesAcked), a.MTU()) // SCTP way (slow)
a.log.Tracef("[%s] updated cwnd=%d ssthresh=%d acked=%d (SS)",
a.name, a.cwnd, a.ssthresh, totalBytesAcked)
a.name, a.CWND(), a.ssthresh, totalBytesAcked)
} else {
a.log.Tracef("[%s] cwnd did not grow: cwnd=%d ssthresh=%d acked=%d FR=%v pending=%d",
a.name, a.cwnd, a.ssthresh, totalBytesAcked, a.inFastRecovery, a.pendingQueue.size())
a.name, a.CWND(), a.ssthresh, totalBytesAcked, a.inFastRecovery, a.pendingQueue.size())
}
} else {
// RFC 4096, sec 7.2.2. Congestion Avoidance
Expand All @@ -1578,11 +1586,11 @@ func (a *Association) onCumulativeTSNAckPointAdvanced(totalBytesAcked int) {
// of data outstanding (i.e., before arrival of the SACK, flight size
// was greater than or equal to cwnd), increase cwnd by MTU, and
// reset partial_bytes_acked to (partial_bytes_acked - cwnd).
if a.partialBytesAcked >= a.cwnd && a.pendingQueue.size() > 0 {
a.partialBytesAcked -= a.cwnd
a.cwnd += a.mtu
if a.partialBytesAcked >= a.CWND() && a.pendingQueue.size() > 0 {
a.partialBytesAcked -= a.CWND()
a.setCWND(a.CWND() + a.MTU())
a.log.Tracef("[%s] updated cwnd=%d ssthresh=%d acked=%d (CA)",
a.name, a.cwnd, a.ssthresh, totalBytesAcked)
a.name, a.CWND(), a.ssthresh, totalBytesAcked)
}
}
}
Expand Down Expand Up @@ -1622,13 +1630,13 @@ func (a *Association) processFastRetransmission(cumTSNAckPoint, htna uint32, cum
// last sent, according to the formula described in Section 7.2.3.
a.inFastRecovery = true
a.fastRecoverExitPoint = htna
a.ssthresh = max32(a.cwnd/2, 4*a.mtu)
a.cwnd = a.ssthresh
a.ssthresh = max32(a.CWND()/2, 4*a.MTU())
a.setCWND(a.ssthresh)
a.partialBytesAcked = 0
a.willRetransmitFast = true

a.log.Tracef("[%s] updated cwnd=%d ssthresh=%d inflight=%d (FR)",
a.name, a.cwnd, a.ssthresh, a.inflightQueue.getNumBytes())
a.name, a.CWND(), a.ssthresh, a.inflightQueue.getNumBytes())
}
}
}
Expand Down Expand Up @@ -1710,9 +1718,9 @@ func (a *Association) handleSack(d *chunkSelectiveAck) error {
// bytes acked were already subtracted by markAsAcked() method
bytesOutstanding := uint32(a.inflightQueue.getNumBytes())
if bytesOutstanding >= d.advertisedReceiverWindowCredit {
a.rwnd = 0
a.setRWND(0)
} else {
a.rwnd = d.advertisedReceiverWindowCredit - bytesOutstanding
a.setRWND(d.advertisedReceiverWindowCredit - bytesOutstanding)
}

err = a.processFastRetransmission(d.cumulativeTSNAck, htna, cumTSNAckPointAdvanced)
Expand Down Expand Up @@ -2120,15 +2128,15 @@ func (a *Association) popPendingDataChunksToSend() ([]*chunkPayloadData, []uint1
continue
}

if uint32(a.inflightQueue.getNumBytes())+dataLen > a.cwnd {
if uint32(a.inflightQueue.getNumBytes())+dataLen > a.CWND() {
break // would exceeds cwnd
}

if dataLen > a.rwnd {
break // no more rwnd
}

a.rwnd -= dataLen
a.setRWND(a.RWND() - dataLen)

a.movePendingDataChunkToInflightQueue(c)
chunks = append(chunks, c)
Expand Down Expand Up @@ -2163,7 +2171,7 @@ func (a *Association) bundleDataChunksIntoPackets(chunks []*chunkPayloadData) []
// single packet. Furthermore, DATA chunks being retransmitted MAY be
// bundled with new DATA chunks, as long as the resulting packet size
// does not exceed the path MTU.
if bytesInPacket+len(c.userData) > int(a.mtu) {
if bytesInPacket+len(c.userData) > int(a.MTU()) {
packets = append(packets, a.createPacket(chunksToSend))
chunksToSend = []chunk{}
bytesInPacket = int(commonHeaderSize)
Expand Down Expand Up @@ -2240,7 +2248,7 @@ func (a *Association) checkPartialReliabilityStatus(c *chunkPayloadData) {
// that are not acked or abandoned yet.
// The caller should hold the lock.
func (a *Association) getDataPacketsToRetransmit() []*packet {
awnd := min32(a.cwnd, a.rwnd)
awnd := min32(a.CWND(), a.RWND())
chunks := []*chunkPayloadData{}
var bytesToSend int
var done bool
Expand All @@ -2255,7 +2263,7 @@ func (a *Association) getDataPacketsToRetransmit() []*packet {
continue
}

if i == 0 && int(a.rwnd) < len(c.userData) {
if i == 0 && int(a.RWND()) < len(c.userData) {
// Send it as a zero window probe
done = true
} else if bytesToSend+len(c.userData) > int(awnd) {
Expand Down Expand Up @@ -2460,10 +2468,10 @@ func (a *Association) onRetransmissionTimeout(id int, nRtos uint) {
// ssthresh = max(cwnd/2, 4*MTU)
// cwnd = 1*MTU

a.ssthresh = max32(a.cwnd/2, 4*a.mtu)
a.cwnd = a.mtu
a.ssthresh = max32(a.CWND()/2, 4*a.MTU())
a.setCWND(a.MTU())
a.log.Tracef("[%s] updated cwnd=%d ssthresh=%d inflight=%d (RTO)",
a.name, a.cwnd, a.ssthresh, a.inflightQueue.getNumBytes())
a.name, a.CWND(), a.ssthresh, a.inflightQueue.getNumBytes())

// RFC 3758 sec 3.5
// A5) Any time the T3-rtx timer expires, on any destination, the sender
Expand All @@ -2488,7 +2496,7 @@ func (a *Association) onRetransmissionTimeout(id int, nRtos uint) {
}
}

a.log.Debugf("[%s] T3-rtx timed out: nRtos=%d cwnd=%d ssthresh=%d", a.name, nRtos, a.cwnd, a.ssthresh)
a.log.Debugf("[%s] T3-rtx timed out: nRtos=%d cwnd=%d ssthresh=%d", a.name, nRtos, a.CWND(), a.ssthresh)

/*
a.log.Debugf(" - advancedPeerTSNAckPoint=%d", a.advancedPeerTSNAckPoint)
Expand Down

0 comments on commit 9847722

Please sign in to comment.