diff --git a/providers/coordinator.go b/providers/coordinator.go index d31253cc5a..68611403ae 100644 --- a/providers/coordinator.go +++ b/providers/coordinator.go @@ -60,28 +60,28 @@ type RunningProvider struct { // isShutdown is only used once during provider shutdown isShutdown bool // provider errors which are evaluated and printed during shutdown of the provider - err error - lock sync.Mutex + err error + lock sync.Mutex + shutdownLock sync.Mutex + interval time.Duration + gracePeriod time.Duration } // initialize the heartbeat with the provider func (p *RunningProvider) heartbeat() error { - interval := 2 * time.Second - gracePeriod := 3 * time.Second - - if err := p.doOneHeartbeat(interval + gracePeriod); err != nil { + if err := p.doOneHeartbeat(p.interval + p.gracePeriod); err != nil { p.Shutdown() return err } go func() { for !p.isCloseOrShutdown() { - if err := p.doOneHeartbeat(interval + gracePeriod); err != nil { + if err := p.doOneHeartbeat(p.interval + p.gracePeriod); err != nil { p.Shutdown() break } - time.Sleep(interval) + time.Sleep(p.interval) } }() @@ -104,8 +104,8 @@ func (p *RunningProvider) doOneHeartbeat(t time.Duration) error { } func (p *RunningProvider) isCloseOrShutdown() bool { - p.lock.Lock() - defer p.lock.Unlock() + p.shutdownLock.Lock() + defer p.shutdownLock.Unlock() return p.isClosed || p.isShutdown } @@ -136,10 +136,16 @@ func (p *RunningProvider) Shutdown() error { if p.Client != nil { p.Client.Kill() } + p.shutdownLock.Lock() p.isClosed = true + p.isShutdown = true + p.shutdownLock.Unlock() + } else { + p.shutdownLock.Lock() + p.isShutdown = true + p.shutdownLock.Unlock() } - p.isShutdown = true return err } @@ -229,11 +235,13 @@ func (c *coordinator) Start(id string, isEphemeral bool, update UpdateProvidersC } res := &RunningProvider{ - Name: provider.Name, - ID: provider.ID, - Plugin: raw.(pp.ProviderPlugin), - Client: client, - Schema: provider.Schema, + Name: provider.Name, + ID: provider.ID, + Plugin: raw.(pp.ProviderPlugin), + Client: client, + Schema: provider.Schema, + interval: 2 * time.Second, + gracePeriod: 3 * time.Second, } if err := res.heartbeat(); err != nil { diff --git a/providers/providers_test.go b/providers/providers_test.go new file mode 100644 index 0000000000..9de2658136 --- /dev/null +++ b/providers/providers_test.go @@ -0,0 +1,59 @@ +// Copyright (c) Mondoo, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package providers + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" + "go.mondoo.com/cnquery/v9/providers-sdk/v1/plugin" +) + +type testPlugin struct { + plugin.Service +} + +func (t *testPlugin) Connect(req *plugin.ConnectReq, callback plugin.ProviderCallback) (*plugin.ConnectRes, error) { + return nil, nil +} + +func (t *testPlugin) MockConnect(req *plugin.ConnectReq, callback plugin.ProviderCallback) (*plugin.ConnectRes, error) { + return nil, nil +} + +func (t *testPlugin) ParseCLI(req *plugin.ParseCLIReq) (*plugin.ParseCLIRes, error) { + return nil, nil +} + +func (t *testPlugin) Shutdown(req *plugin.ShutdownReq) (*plugin.ShutdownRes, error) { + // sleep more than the heartbeat interval to ensure that even if shutting down + // the provider can still respond to heartbeats + time.Sleep(10 * time.Second) + return &plugin.ShutdownRes{}, nil +} + +func (t *testPlugin) GetData(req *plugin.DataReq) (*plugin.DataRes, error) { + return nil, nil +} + +func (t *testPlugin) StoreData(req *plugin.StoreReq) (*plugin.StoreRes, error) { + return nil, nil +} + +func TestProviderShutdown(t *testing.T) { + s := &RunningProvider{ + Plugin: &testPlugin{}, + interval: 500 * time.Millisecond, + gracePeriod: 500 * time.Millisecond, + } + err := s.heartbeat() + require.NoError(t, err) + require.False(t, s.isCloseOrShutdown()) + // the shutdown here takes 10 seconds, whereas the heartbeat interval is every second. + // this means that this provider gets multiple heartbeats while shutting down + err = s.Shutdown() + require.NoError(t, err) + require.True(t, s.isCloseOrShutdown()) +}