Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix inconsistencies in timer implementations #338

Merged
merged 1 commit into from
Jul 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions ack_timer.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ type ackTimerObserver interface {
onAckTimeout()
}

type ackTimerState int
type ackTimerState uint8

const (
ackTimerStopped ackTimerState = iota
Expand All @@ -28,10 +28,11 @@ const (

// ackTimer provides the retnransmission timer conforms with RFC 4960 Sec 6.3.1
type ackTimer struct {
timer *time.Timer
observer ackTimerObserver
mutex sync.RWMutex
mutex sync.Mutex
state ackTimerState
timer *time.Timer
pending uint8
}

// newAckTimer creates a new acknowledgement timer used to enable delayed ack.
Expand All @@ -44,7 +45,7 @@ func newAckTimer(observer ackTimerObserver) *ackTimer {

func (t *ackTimer) timeout() {
t.mutex.Lock()
if t.state == ackTimerStarted {
if t.pending--; t.pending == 0 && t.state == ackTimerStarted {
t.state = ackTimerStopped
defer t.observer.onAckTimeout()
}
Expand All @@ -62,6 +63,7 @@ func (t *ackTimer) start() bool {
}

t.state = ackTimerStarted
t.pending++
t.timer.Reset(ackInterval)
return true
}
Expand All @@ -73,7 +75,9 @@ func (t *ackTimer) stop() {
defer t.mutex.Unlock()

if t.state == ackTimerStarted {
t.timer.Stop()
if t.timer.Stop() {
t.pending--
}
t.state = ackTimerStopped
}
}
Expand All @@ -84,17 +88,17 @@ func (t *ackTimer) close() {
t.mutex.Lock()
defer t.mutex.Unlock()

if t.state == ackTimerStarted {
t.timer.Stop()
if t.state == ackTimerStarted && t.timer.Stop() {
t.pending--
}
t.state = ackTimerClosed
}

// isRunning tests if the timer is running.
// Debug purpose only
func (t *ackTimer) isRunning() bool {
t.mutex.RLock()
defer t.mutex.RUnlock()
t.mutex.Lock()
defer t.mutex.Unlock()

return t.state == ackTimerStarted
}
113 changes: 55 additions & 58 deletions rtx_timer.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,19 +118,28 @@ type rtxTimerObserver interface {
onRetransmissionFailure(timerID int)
}

type rtxTimerState uint8

const (
rtxTimerStopped rtxTimerState = iota
rtxTimerStarted
rtxTimerClosed
)

// rtxTimer provides the retnransmission timer conforms with RFC 4960 Sec 6.3.1
type rtxTimer struct {
id int
timer *time.Timer
observer rtxTimerObserver
id int
maxRetrans uint
stopFunc stopTimerLoop
closed bool
mutex sync.RWMutex
rtoMax float64
mutex sync.Mutex
rto float64
nRtos uint
state rtxTimerState
pending uint8
}

type stopTimerLoop func()

// newRTXTimer creates a new retransmission timer.
// if maxRetrans is set to 0, it will keep retransmitting until stop() is called.
// (it will never make onRetransmissionFailure() callback.
Expand All @@ -146,62 +155,50 @@ func newRTXTimer(id int, observer rtxTimerObserver, maxRetrans uint,
if timer.rtoMax == 0 {
timer.rtoMax = defaultRTOMax
}
timer.timer = time.AfterFunc(math.MaxInt64, timer.timeout)
timer.timer.Stop()
return &timer
}

func (t *rtxTimer) calculateNextTimeout() time.Duration {
timeout := calculateNextTimeout(t.rto, t.nRtos, t.rtoMax)
return time.Duration(timeout) * time.Millisecond
}

func (t *rtxTimer) timeout() {
t.mutex.Lock()
if t.pending--; t.pending == 0 && t.state == rtxTimerStarted {
if t.nRtos++; t.maxRetrans == 0 || t.nRtos <= t.maxRetrans {
t.timer.Reset(t.calculateNextTimeout())
t.pending++
defer t.observer.onRetransmissionTimeout(t.id, t.nRtos)
} else {
t.state = rtxTimerStopped
defer t.observer.onRetransmissionFailure(t.id)
}
}
t.mutex.Unlock()
}

// start starts the timer.
func (t *rtxTimer) start(rto float64) bool {
t.mutex.Lock()
defer t.mutex.Unlock()

// this timer is already closed
if t.closed {
return false
}

// this is a noop if the timer is always running
if t.stopFunc != nil {
// this timer is already closed or aleady running
if t.state != rtxTimerStopped {
return false
}

// Note: rto value is intentionally not capped by RTO.Min to allow
// fast timeout for the tests. Non-test code should pass in the
// rto generated by rtoManager getRTO() method which caps the
// value at RTO.Min or at RTO.Max.
var nRtos uint

cancelCh := make(chan struct{})

go func() {
canceling := false

timer := time.NewTimer(math.MaxInt64)
timer.Stop()

for !canceling {
timeout := calculateNextTimeout(rto, nRtos, t.rtoMax)
timer.Reset(time.Duration(timeout) * time.Millisecond)

select {
case <-timer.C:
nRtos++
if t.maxRetrans == 0 || nRtos <= t.maxRetrans {
t.observer.onRetransmissionTimeout(t.id, nRtos)
} else {
t.stop()
t.observer.onRetransmissionFailure(t.id)
}
case <-cancelCh:
canceling = true
timer.Stop()
}
}
}()

t.stopFunc = func() {
close(cancelCh)
}

t.rto = rto
t.nRtos = 0
t.state = rtxTimerStarted
t.pending++
t.timer.Reset(t.calculateNextTimeout())
return true
}

Expand All @@ -210,9 +207,11 @@ func (t *rtxTimer) stop() {
t.mutex.Lock()
defer t.mutex.Unlock()

if t.stopFunc != nil {
t.stopFunc()
t.stopFunc = nil
if t.state == rtxTimerStarted {
if t.timer.Stop() {
t.pending--
}
t.state = rtxTimerStopped
}
}

Expand All @@ -222,21 +221,19 @@ func (t *rtxTimer) close() {
t.mutex.Lock()
defer t.mutex.Unlock()

if t.stopFunc != nil {
t.stopFunc()
t.stopFunc = nil
if t.state == rtxTimerStarted && t.timer.Stop() {
t.pending--
}

t.closed = true
t.state = rtxTimerClosed
}

// isRunning tests if the timer is running.
// Debug purpose only
func (t *rtxTimer) isRunning() bool {
t.mutex.RLock()
defer t.mutex.RUnlock()
t.mutex.Lock()
defer t.mutex.Unlock()

return (t.stopFunc != nil)
return t.state == rtxTimerStarted
}

func calculateNextTimeout(rto float64, nRtos uint, rtoMax float64) float64 {
Expand Down
Loading