diff --git a/sizedwaitgroup.go b/sizedwaitgroup.go index 4a0685d..19faae5 100644 --- a/sizedwaitgroup.go +++ b/sizedwaitgroup.go @@ -82,3 +82,20 @@ func (s *SizedWaitGroup) Done() { func (s *SizedWaitGroup) Wait() { s.wg.Wait() } + +// Wait blocks until the SizedWaitGroup counter is zero or the context is Done. +// See sync.WaitGroup documentation for more information. +func (s *SizedWaitGroup) WaitWithContext(ctx context.Context) error { + done := make(chan struct{}) + go func() { + defer close(done) + s.Wait() + done <- struct{}{} + }() + select { + case <-ctx.Done(): + return ctx.Err() + case <-done: + } + return nil +} diff --git a/sizedwaitgroup_test.go b/sizedwaitgroup_test.go index 2bf1859..53316e7 100644 --- a/sizedwaitgroup_test.go +++ b/sizedwaitgroup_test.go @@ -83,3 +83,26 @@ func TestAddWithContext(t *testing.T) { } } + +func TestWaitWithContext(t *testing.T) { + t.Run("cancelled context error is returned", func(t *testing.T) { + ctx, cancelFunc := context.WithCancel(context.TODO()) + + swg := New(1) + swg.Add() + cancelFunc() + + if err := swg.WaitWithContext(ctx); err != context.Canceled { + t.Fatalf("expected cancelled context: %s", err) + } + }) + t.Run("done group returns nil", func(t *testing.T) { + swg := New(1) + swg.Add() + swg.Done() + + if err := swg.WaitWithContext(context.TODO()); err != nil { + t.Fatalf("expected nil: %s", err) + } + }) +}