diff --git a/ack_timer.go b/ack_timer.go index 3d9b43e0..7de156bb 100644 --- a/ack_timer.go +++ b/ack_timer.go @@ -4,6 +4,7 @@ package sctp import ( + "math" "sync" "time" ) @@ -17,22 +18,32 @@ type ackTimerObserver interface { onAckTimeout() } +type ackTimerState int + +const ( + ackTimerStopped ackTimerState = iota + ackTimerStarted + ackTimerClosed +) + // ackTimer provides the retnransmission timer conforms with RFC 4960 Sec 6.3.1 type ackTimer struct { observer ackTimerObserver - interval time.Duration - stopFunc stopAckTimerLoop - closed bool mutex sync.RWMutex + state ackTimerState + timer *time.Timer + cancel chan struct{} } -type stopAckTimerLoop func() - // newAckTimer creates a new acknowledgement timer used to enable delayed ack. func newAckTimer(observer ackTimerObserver) *ackTimer { + timer := time.NewTimer(math.MaxInt64) + timer.Stop() + return &ackTimer{ observer: observer, - interval: ackInterval, + timer: timer, + cancel: make(chan struct{}, 1), } } @@ -41,34 +52,35 @@ func (t *ackTimer) start() bool { t.mutex.Lock() defer t.mutex.Unlock() - // this timer is already closed - if t.closed { + // this timer is already closed or already running + if t.state != ackTimerStopped { return false } - - // this is a noop if the timer is already running - if t.stopFunc != nil { - return false - } - - cancelCh := make(chan struct{}) + t.state = ackTimerStarted go func() { - timer := time.NewTimer(t.interval) + t.timer.Reset(ackInterval) select { - case <-timer.C: - t.stop() - t.observer.onAckTimeout() - case <-cancelCh: - timer.Stop() + case <-t.timer.C: + t.mutex.Lock() + switch t.state { + case ackTimerStopped: + // if the timer is already stopped empty the cancel channel + <-t.cancel + case ackTimerStarted: + t.state = ackTimerStopped + defer t.observer.onAckTimeout() + case ackTimerClosed: + } + t.mutex.Unlock() + case <-t.cancel: + if !t.timer.Stop() { + <-t.timer.C + } } }() - t.stopFunc = func() { - close(cancelCh) - } - return true } @@ -78,9 +90,9 @@ func (t *ackTimer) stop() { t.mutex.Lock() defer t.mutex.Unlock() - if t.stopFunc != nil { - t.stopFunc() - t.stopFunc = nil + if t.state == ackTimerStarted { + t.cancel <- struct{}{} + t.state = ackTimerStopped } } @@ -90,12 +102,10 @@ func (t *ackTimer) close() { t.mutex.Lock() defer t.mutex.Unlock() - if t.stopFunc != nil { - t.stopFunc() - t.stopFunc = nil + if t.state == ackTimerStarted { + t.cancel <- struct{}{} } - - t.closed = true + t.state = ackTimerClosed } // isRunning tests if the timer is running. @@ -104,5 +114,5 @@ func (t *ackTimer) isRunning() bool { t.mutex.RLock() defer t.mutex.RUnlock() - return (t.stopFunc != nil) + return t.state == ackTimerStarted } diff --git a/rtx_timer.go b/rtx_timer.go index 42848abc..354825b5 100644 --- a/rtx_timer.go +++ b/rtx_timer.go @@ -175,9 +175,12 @@ func (t *rtxTimer) start(rto float64) bool { go func() { canceling := false + timer := time.NewTimer(math.MaxInt64) + timer.Stop() + for !canceling { timeout := calculateNextTimeout(rto, nRtos, t.rtoMax) - timer := time.NewTimer(time.Duration(timeout) * time.Millisecond) + timer.Reset(time.Duration(timeout) * time.Millisecond) select { case <-timer.C: