From 984772253077b6dc7446f15613f1a05a16385826 Mon Sep 17 00:00:00 2001 From: Sean DuBois Date: Sun, 20 Aug 2023 20:33:16 -0400 Subject: [PATCH] Fix races from 3fba82395 Found in pion/webrtc#2112 --- association.go | 68 ++++++++++++++++++++++++++++---------------------- 1 file changed, 38 insertions(+), 30 deletions(-) diff --git a/association.go b/association.go index 04cac903..ce3725ec 100644 --- a/association.go +++ b/association.go @@ -323,9 +323,9 @@ func createAssociation(config Config) *Association { // o The initial cwnd before DATA transmission or after a sufficiently // long idle period MUST be set to min(4*MTU, max (2*MTU, 4380 // bytes)). - a.cwnd = min32(4*a.mtu, max32(2*a.mtu, 4380)) + a.setCWND(min32(4*a.MTU(), max32(2*a.MTU(), 4380))) a.log.Tracef("[%s] updated cwnd=%d ssthresh=%d inflight=%d (INI)", - a.name, a.cwnd, a.ssthresh, a.inflightQueue.getNumBytes()) + a.name, a.CWND(), a.ssthresh, a.inflightQueue.getNumBytes()) a.srtt.Store(float64(0)) a.t1Init = newRTXTimer(timerT1Init, a, maxInitRetrans) @@ -747,7 +747,7 @@ func (a *Association) gatherOutboundFastRetransmissionPackets(rawPackets [][]byt // packet. dataChunkSize := dataChunkHeaderSize + uint32(len(c.userData)) - if a.mtu < fastRetransSize+dataChunkSize { + if a.MTU() < fastRetransSize+dataChunkSize { break } @@ -1024,11 +1024,19 @@ func (a *Association) CWND() uint32 { return atomic.LoadUint32(&a.cwnd) } +func (a *Association) setCWND(cwnd uint32) { + atomic.StoreUint32(&a.cwnd, cwnd) +} + // RWND returns the association's current receiver window (rwnd) func (a *Association) RWND() uint32 { return atomic.LoadUint32(&a.rwnd) } +func (a *Association) setRWND(rwnd uint32) { + atomic.StoreUint32(&a.rwnd, rwnd) +} + // SRTT returns the latest smoothed round-trip time (srrt) func (a *Association) SRTT() float64 { return a.srtt.Load().(float64) //nolint:forcetypeassert @@ -1144,16 +1152,16 @@ func (a *Association) handleInitAck(p *packet, i *chunkInitAck) error { return nil } - a.rwnd = i.advertisedReceiverWindowCredit - a.log.Debugf("[%s] initial rwnd=%d", a.name, a.rwnd) + a.setRWND(i.advertisedReceiverWindowCredit) + a.log.Debugf("[%s] initial rwnd=%d", a.name, a.RWND()) // RFC 4690 Sec 7.2.1 // o The initial value of ssthresh MAY be arbitrarily high (for // example, implementations MAY use the size of the receiver // advertised window). - a.ssthresh = a.rwnd + a.ssthresh = a.RWND() a.log.Tracef("[%s] updated cwnd=%d ssthresh=%d inflight=%d (INI)", - a.name, a.cwnd, a.ssthresh, a.inflightQueue.getNumBytes()) + a.name, a.CWND(), a.ssthresh, a.inflightQueue.getNumBytes()) a.t1Init.stop() a.storedInit = nil @@ -1542,7 +1550,7 @@ func (a *Association) onCumulativeTSNAckPointAdvanced(totalBytesAcked int) { } // Update congestion control parameters - if a.cwnd <= a.ssthresh { + if a.CWND() <= a.ssthresh { // RFC 4096, sec 7.2.1. Slow-Start // o When cwnd is less than or equal to ssthresh, an SCTP endpoint MUST // use the slow-start algorithm to increase cwnd only if the current @@ -1556,13 +1564,13 @@ func (a *Association) onCumulativeTSNAckPointAdvanced(totalBytesAcked int) { // path MTU. if !a.inFastRecovery && a.pendingQueue.size() > 0 { - a.cwnd += min32(uint32(totalBytesAcked), a.cwnd) // TCP way - // a.cwnd += min32(uint32(totalBytesAcked), a.mtu) // SCTP way (slow) + a.setCWND(a.CWND() + min32(uint32(totalBytesAcked), a.CWND())) + // a.cwnd += min32(uint32(totalBytesAcked), a.MTU()) // SCTP way (slow) a.log.Tracef("[%s] updated cwnd=%d ssthresh=%d acked=%d (SS)", - a.name, a.cwnd, a.ssthresh, totalBytesAcked) + a.name, a.CWND(), a.ssthresh, totalBytesAcked) } else { a.log.Tracef("[%s] cwnd did not grow: cwnd=%d ssthresh=%d acked=%d FR=%v pending=%d", - a.name, a.cwnd, a.ssthresh, totalBytesAcked, a.inFastRecovery, a.pendingQueue.size()) + a.name, a.CWND(), a.ssthresh, totalBytesAcked, a.inFastRecovery, a.pendingQueue.size()) } } else { // RFC 4096, sec 7.2.2. Congestion Avoidance @@ -1578,11 +1586,11 @@ func (a *Association) onCumulativeTSNAckPointAdvanced(totalBytesAcked int) { // of data outstanding (i.e., before arrival of the SACK, flight size // was greater than or equal to cwnd), increase cwnd by MTU, and // reset partial_bytes_acked to (partial_bytes_acked - cwnd). - if a.partialBytesAcked >= a.cwnd && a.pendingQueue.size() > 0 { - a.partialBytesAcked -= a.cwnd - a.cwnd += a.mtu + if a.partialBytesAcked >= a.CWND() && a.pendingQueue.size() > 0 { + a.partialBytesAcked -= a.CWND() + a.setCWND(a.CWND() + a.MTU()) a.log.Tracef("[%s] updated cwnd=%d ssthresh=%d acked=%d (CA)", - a.name, a.cwnd, a.ssthresh, totalBytesAcked) + a.name, a.CWND(), a.ssthresh, totalBytesAcked) } } } @@ -1622,13 +1630,13 @@ func (a *Association) processFastRetransmission(cumTSNAckPoint, htna uint32, cum // last sent, according to the formula described in Section 7.2.3. a.inFastRecovery = true a.fastRecoverExitPoint = htna - a.ssthresh = max32(a.cwnd/2, 4*a.mtu) - a.cwnd = a.ssthresh + a.ssthresh = max32(a.CWND()/2, 4*a.MTU()) + a.setCWND(a.ssthresh) a.partialBytesAcked = 0 a.willRetransmitFast = true a.log.Tracef("[%s] updated cwnd=%d ssthresh=%d inflight=%d (FR)", - a.name, a.cwnd, a.ssthresh, a.inflightQueue.getNumBytes()) + a.name, a.CWND(), a.ssthresh, a.inflightQueue.getNumBytes()) } } } @@ -1710,9 +1718,9 @@ func (a *Association) handleSack(d *chunkSelectiveAck) error { // bytes acked were already subtracted by markAsAcked() method bytesOutstanding := uint32(a.inflightQueue.getNumBytes()) if bytesOutstanding >= d.advertisedReceiverWindowCredit { - a.rwnd = 0 + a.setRWND(0) } else { - a.rwnd = d.advertisedReceiverWindowCredit - bytesOutstanding + a.setRWND(d.advertisedReceiverWindowCredit - bytesOutstanding) } err = a.processFastRetransmission(d.cumulativeTSNAck, htna, cumTSNAckPointAdvanced) @@ -2120,7 +2128,7 @@ func (a *Association) popPendingDataChunksToSend() ([]*chunkPayloadData, []uint1 continue } - if uint32(a.inflightQueue.getNumBytes())+dataLen > a.cwnd { + if uint32(a.inflightQueue.getNumBytes())+dataLen > a.CWND() { break // would exceeds cwnd } @@ -2128,7 +2136,7 @@ func (a *Association) popPendingDataChunksToSend() ([]*chunkPayloadData, []uint1 break // no more rwnd } - a.rwnd -= dataLen + a.setRWND(a.RWND() - dataLen) a.movePendingDataChunkToInflightQueue(c) chunks = append(chunks, c) @@ -2163,7 +2171,7 @@ func (a *Association) bundleDataChunksIntoPackets(chunks []*chunkPayloadData) [] // single packet. Furthermore, DATA chunks being retransmitted MAY be // bundled with new DATA chunks, as long as the resulting packet size // does not exceed the path MTU. - if bytesInPacket+len(c.userData) > int(a.mtu) { + if bytesInPacket+len(c.userData) > int(a.MTU()) { packets = append(packets, a.createPacket(chunksToSend)) chunksToSend = []chunk{} bytesInPacket = int(commonHeaderSize) @@ -2240,7 +2248,7 @@ func (a *Association) checkPartialReliabilityStatus(c *chunkPayloadData) { // that are not acked or abandoned yet. // The caller should hold the lock. func (a *Association) getDataPacketsToRetransmit() []*packet { - awnd := min32(a.cwnd, a.rwnd) + awnd := min32(a.CWND(), a.RWND()) chunks := []*chunkPayloadData{} var bytesToSend int var done bool @@ -2255,7 +2263,7 @@ func (a *Association) getDataPacketsToRetransmit() []*packet { continue } - if i == 0 && int(a.rwnd) < len(c.userData) { + if i == 0 && int(a.RWND()) < len(c.userData) { // Send it as a zero window probe done = true } else if bytesToSend+len(c.userData) > int(awnd) { @@ -2460,10 +2468,10 @@ func (a *Association) onRetransmissionTimeout(id int, nRtos uint) { // ssthresh = max(cwnd/2, 4*MTU) // cwnd = 1*MTU - a.ssthresh = max32(a.cwnd/2, 4*a.mtu) - a.cwnd = a.mtu + a.ssthresh = max32(a.CWND()/2, 4*a.MTU()) + a.setCWND(a.MTU()) a.log.Tracef("[%s] updated cwnd=%d ssthresh=%d inflight=%d (RTO)", - a.name, a.cwnd, a.ssthresh, a.inflightQueue.getNumBytes()) + a.name, a.CWND(), a.ssthresh, a.inflightQueue.getNumBytes()) // RFC 3758 sec 3.5 // A5) Any time the T3-rtx timer expires, on any destination, the sender @@ -2488,7 +2496,7 @@ func (a *Association) onRetransmissionTimeout(id int, nRtos uint) { } } - a.log.Debugf("[%s] T3-rtx timed out: nRtos=%d cwnd=%d ssthresh=%d", a.name, nRtos, a.cwnd, a.ssthresh) + a.log.Debugf("[%s] T3-rtx timed out: nRtos=%d cwnd=%d ssthresh=%d", a.name, nRtos, a.CWND(), a.ssthresh) /* a.log.Debugf(" - advancedPeerTSNAckPoint=%d", a.advancedPeerTSNAckPoint)