diff --git a/ack_timer.go b/ack_timer.go index 3d9b43e0..4d50b319 100644 --- a/ack_timer.go +++ b/ack_timer.go @@ -4,6 +4,7 @@ package sctp import ( + "math" "sync" "time" ) @@ -17,22 +18,32 @@ type ackTimerObserver interface { onAckTimeout() } +type ackTimerState int + +const ( + ackTimerStopped ackTimerState = iota + ackTimerStarted + ackTimerClosed +) + // ackTimer provides the retnransmission timer conforms with RFC 4960 Sec 6.3.1 type ackTimer struct { observer ackTimerObserver - interval time.Duration - stopFunc stopAckTimerLoop - closed bool mutex sync.RWMutex + state ackTimerState + timer *time.Timer + cancel chan struct{} } -type stopAckTimerLoop func() - // newAckTimer creates a new acknowledgement timer used to enable delayed ack. func newAckTimer(observer ackTimerObserver) *ackTimer { + timer := time.NewTimer(math.MaxInt64) + timer.Stop() + return &ackTimer{ observer: observer, - interval: ackInterval, + timer: timer, + cancel: make(chan struct{}, 1), } } @@ -41,34 +52,41 @@ func (t *ackTimer) start() bool { t.mutex.Lock() defer t.mutex.Unlock() - // this timer is already closed - if t.closed { + // this timer is already closed or already running + if t.state != ackTimerStopped { return false } + t.state = ackTimerStarted - // this is a noop if the timer is already running - if t.stopFunc != nil { - return false + select { + case <-t.cancel: + return true + default: } - cancelCh := make(chan struct{}) - go func() { - timer := time.NewTimer(t.interval) + t.timer.Reset(ackInterval) select { - case <-timer.C: - t.stop() - t.observer.onAckTimeout() - case <-cancelCh: - timer.Stop() + case <-t.timer.C: + t.mutex.Lock() + switch t.state { + case ackTimerStopped: + // if the timer is already stopped empty the cancel channel + <-t.cancel + case ackTimerStarted: + t.state = ackTimerStopped + defer t.observer.onAckTimeout() + case ackTimerClosed: + } + t.mutex.Unlock() + case <-t.cancel: + if !t.timer.Stop() { + <-t.timer.C + } } }() - t.stopFunc = func() { - close(cancelCh) - } - return true } @@ -78,9 +96,9 @@ func (t *ackTimer) stop() { t.mutex.Lock() defer t.mutex.Unlock() - if t.stopFunc != nil { - t.stopFunc() - t.stopFunc = nil + if t.state == ackTimerStarted { + t.cancel <- struct{}{} + t.state = ackTimerStopped } } @@ -90,12 +108,10 @@ func (t *ackTimer) close() { t.mutex.Lock() defer t.mutex.Unlock() - if t.stopFunc != nil { - t.stopFunc() - t.stopFunc = nil + if t.state == ackTimerStarted { + t.cancel <- struct{}{} } - - t.closed = true + t.state = ackTimerClosed } // isRunning tests if the timer is running. @@ -104,5 +120,5 @@ func (t *ackTimer) isRunning() bool { t.mutex.RLock() defer t.mutex.RUnlock() - return (t.stopFunc != nil) + return t.state == ackTimerStarted } diff --git a/ack_timer_test.go b/ack_timer_test.go index 145b6cde..575b3ecc 100644 --- a/ack_timer_test.go +++ b/ack_timer_test.go @@ -70,14 +70,16 @@ func TestAckTimer(t *testing.T) { }, }) - // should start ok - ok := rt.start() - assert.True(t, ok, "start() should succeed") - assert.True(t, rt.isRunning(), "should be running") + for i := 0; i < 2; i++ { + // should start ok + ok := rt.start() + assert.True(t, ok, "start() should succeed") + assert.True(t, rt.isRunning(), "should be running") - // stop immedidately - rt.stop() - assert.False(t, rt.isRunning(), "should not be running") + // stop immedidately + rt.stop() + assert.False(t, rt.isRunning(), "should not be running") + } // Sleep more than 200msec of interval to test if it never times out time.Sleep(ackInterval + 50*time.Millisecond) @@ -86,7 +88,7 @@ func TestAckTimer(t *testing.T) { "should not be timed out (actual: %d)", atomic.LoadUint32(&nCbs)) // can start again - ok = rt.start() + ok := rt.start() assert.True(t, ok, "start() should succeed again") assert.True(t, rt.isRunning(), "should be running") diff --git a/rtx_timer.go b/rtx_timer.go index 42848abc..354825b5 100644 --- a/rtx_timer.go +++ b/rtx_timer.go @@ -175,9 +175,12 @@ func (t *rtxTimer) start(rto float64) bool { go func() { canceling := false + timer := time.NewTimer(math.MaxInt64) + timer.Stop() + for !canceling { timeout := calculateNextTimeout(rto, nRtos, t.rtoMax) - timer := time.NewTimer(time.Duration(timeout) * time.Millisecond) + timer.Reset(time.Duration(timeout) * time.Millisecond) select { case <-timer.C: