diff --git a/errgroup/errgroup.go b/errgroup/errgroup.go index 9857fe5..9d4f20b 100644 --- a/errgroup/errgroup.go +++ b/errgroup/errgroup.go @@ -8,6 +8,7 @@ package errgroup import ( "context" + "fmt" "sync" ) @@ -51,16 +52,23 @@ func (g *Group) Wait() error { func (g *Group) Go(f func() error) { g.wg.Add(1) + var err error go func() { - defer g.wg.Done() + defer func() { + if e := recover(); e != nil { + err = fmt.Errorf("errgroup: recover from %+v", e) + } - if err := f(); err != nil { - g.errOnce.Do(func() { - g.err = err - if g.cancel != nil { - g.cancel() - } - }) - } + if err != nil { + g.errOnce.Do(func() { + g.err = err + if g.cancel != nil { + g.cancel() + } + }) + } + g.wg.Done() + }() + err = f() }() } diff --git a/errgroup/errgroup_test.go b/errgroup/errgroup_test.go index 5a0b9cb..31bafe3 100644 --- a/errgroup/errgroup_test.go +++ b/errgroup/errgroup_test.go @@ -174,3 +174,17 @@ func TestWithContext(t *testing.T) { } } } + +func TestGroup_panic(t *testing.T) { + g, _ := errgroup.WithContext(context.Background()) + + g.Go(func() error { + panic("Ops!!!") + }) + + if err := g.Wait(); err == nil { + t.Errorf("after %T.Go(func() error { panic(\"Ops!!!\") })\n"+ + "g.Wait() = %v; want %v", + g, err, "a non-nil error") + } +}