diff --git a/utils/hedge.go b/utils/hedge.go index 85c0e99c..6d5e0044 100644 --- a/utils/hedge.go +++ b/utils/hedge.go @@ -16,11 +16,14 @@ package utils import ( "context" + "errors" "time" "go.uber.org/multierr" ) +var ErrMaxAttemptsReached = errors.New("max attempts reached") + type HedgeParams[T any] struct { Timeout time.Duration RetryDelay time.Duration @@ -47,7 +50,7 @@ func HedgeCall[T any](ctx context.Context, params HedgeParams[T]) (v T, err erro ch <- result{value, err} } - var attempt int + var attempt, done int delay := time.NewTimer(0) defer delay.Stop() @@ -59,10 +62,18 @@ func HedgeCall[T any](ctx context.Context, params HedgeParams[T]) (v T, err erro delay.Reset(params.RetryDelay) } case res := <-ch: - if res.err == nil || params.IsRecoverable == nil || !params.IsRecoverable(res.err) { - return res.value, res.err + if res.err == nil { + return res.value, nil } + err = multierr.Append(err, res.err) + if params.IsRecoverable != nil && !params.IsRecoverable(res.err) { + return + } + if done++; done == params.MaxAttempts { + err = multierr.Append(err, ErrMaxAttemptsReached) + return + } case <-ctx.Done(): err = multierr.Append(err, ctx.Err()) return diff --git a/utils/hedge_test.go b/utils/hedge_test.go index d26e0f5d..e5b22e30 100644 --- a/utils/hedge_test.go +++ b/utils/hedge_test.go @@ -16,6 +16,7 @@ package utils import ( "context" + "errors" "testing" "time" @@ -24,18 +25,82 @@ import ( ) func TestHedgeCall(t *testing.T) { - var attempts atomic.Uint32 - res, err := HedgeCall(context.Background(), HedgeParams[uint32]{ - Timeout: 200 * time.Millisecond, - RetryDelay: 50 * time.Millisecond, - MaxAttempts: 2, - Func: func(context.Context) (uint32, error) { - n := attempts.Add(1) - time.Sleep(75 * time.Millisecond) - return n, nil - }, + t.Run("success", func(t *testing.T) { + var attempts atomic.Uint32 + res, err := HedgeCall(context.Background(), HedgeParams[uint32]{ + Timeout: 200 * time.Millisecond, + RetryDelay: 50 * time.Millisecond, + MaxAttempts: 2, + Func: func(context.Context) (uint32, error) { + n := attempts.Add(1) + time.Sleep(75 * time.Millisecond) + return n, nil + }, + }) + require.NoError(t, err) + require.EqualValues(t, 1, res) + require.EqualValues(t, 2, attempts.Load()) + }) + + t.Run("recoverable error", func(t *testing.T) { + var recoverableErr = errors.New("recoverable") + + var attempts atomic.Uint32 + res, err := HedgeCall(context.Background(), HedgeParams[uint32]{ + Timeout: 200 * time.Millisecond, + RetryDelay: 50 * time.Millisecond, + MaxAttempts: 2, + IsRecoverable: func(err error) bool { + return errors.Is(err, recoverableErr) + }, + Func: func(context.Context) (uint32, error) { + n := attempts.Add(1) + if n == 1 { + return n, recoverableErr + } + return n, nil + }, + }) + require.NoError(t, err) + require.EqualValues(t, 2, res) + }) + + t.Run("unrecoverable error", func(t *testing.T) { + var recoverableErr = errors.New("recoverable") + var unrecoverableErr = errors.New("unrecoverable") + + var attempts atomic.Uint32 + _, err := HedgeCall(context.Background(), HedgeParams[uint32]{ + Timeout: 200 * time.Millisecond, + RetryDelay: 50 * time.Millisecond, + MaxAttempts: 3, + IsRecoverable: func(err error) bool { + return !errors.Is(err, unrecoverableErr) + }, + Func: func(context.Context) (uint32, error) { + n := attempts.Add(1) + if n == 1 { + return n, recoverableErr + } + return n, unrecoverableErr + }, + }) + require.ErrorIs(t, err, unrecoverableErr) + require.EqualValues(t, 2, attempts.Load()) + }) + + t.Run("max failures", func(t *testing.T) { + var attempts atomic.Uint32 + _, err := HedgeCall(context.Background(), HedgeParams[uint32]{ + Timeout: 200 * time.Millisecond, + RetryDelay: 50 * time.Millisecond, + MaxAttempts: 2, + Func: func(context.Context) (uint32, error) { + n := attempts.Add(1) + return n, errors.New("failure") + }, + }) + require.ErrorIs(t, err, ErrMaxAttemptsReached) + require.EqualValues(t, 2, attempts.Load()) }) - require.NoError(t, err) - require.EqualValues(t, 1, res) - require.EqualValues(t, 2, attempts.Load()) }