Skip to content

Commit

Permalink
🐛 Allow for provider shutdown to take longer than a heartbeat. (#3001)
Browse files Browse the repository at this point in the history
* 🐛 Allow for Shutdown to take longer than a heartbeat.

Signed-off-by: Preslav <[email protected]>

* Add a test for provider shutdown.

Signed-off-by: Preslav <[email protected]>

* Make grace period and interval configurable on the provider.

Signed-off-by: Preslav <[email protected]>

* Add extra assertions for isCloseOrShutdown.

Signed-off-by: Preslav <[email protected]>

---------

Signed-off-by: Preslav <[email protected]>
  • Loading branch information
preslavgerchev authored Jan 12, 2024
1 parent b0d48e4 commit 8266a63
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 16 deletions.
40 changes: 24 additions & 16 deletions providers/coordinator.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,28 +60,28 @@ type RunningProvider struct {
// isShutdown is only used once during provider shutdown
isShutdown bool
// provider errors which are evaluated and printed during shutdown of the provider
err error
lock sync.Mutex
err error
lock sync.Mutex
shutdownLock sync.Mutex
interval time.Duration
gracePeriod time.Duration
}

// initialize the heartbeat with the provider
func (p *RunningProvider) heartbeat() error {
interval := 2 * time.Second
gracePeriod := 3 * time.Second

if err := p.doOneHeartbeat(interval + gracePeriod); err != nil {
if err := p.doOneHeartbeat(p.interval + p.gracePeriod); err != nil {
p.Shutdown()
return err
}

go func() {
for !p.isCloseOrShutdown() {
if err := p.doOneHeartbeat(interval + gracePeriod); err != nil {
if err := p.doOneHeartbeat(p.interval + p.gracePeriod); err != nil {
p.Shutdown()
break
}

time.Sleep(interval)
time.Sleep(p.interval)
}
}()

Expand All @@ -104,8 +104,8 @@ func (p *RunningProvider) doOneHeartbeat(t time.Duration) error {
}

func (p *RunningProvider) isCloseOrShutdown() bool {
p.lock.Lock()
defer p.lock.Unlock()
p.shutdownLock.Lock()
defer p.shutdownLock.Unlock()
return p.isClosed || p.isShutdown
}

Expand Down Expand Up @@ -136,10 +136,16 @@ func (p *RunningProvider) Shutdown() error {
if p.Client != nil {
p.Client.Kill()
}
p.shutdownLock.Lock()
p.isClosed = true
p.isShutdown = true
p.shutdownLock.Unlock()
} else {
p.shutdownLock.Lock()
p.isShutdown = true
p.shutdownLock.Unlock()
}

p.isShutdown = true
return err
}

Expand Down Expand Up @@ -229,11 +235,13 @@ func (c *coordinator) Start(id string, isEphemeral bool, update UpdateProvidersC
}

res := &RunningProvider{
Name: provider.Name,
ID: provider.ID,
Plugin: raw.(pp.ProviderPlugin),
Client: client,
Schema: provider.Schema,
Name: provider.Name,
ID: provider.ID,
Plugin: raw.(pp.ProviderPlugin),
Client: client,
Schema: provider.Schema,
interval: 2 * time.Second,
gracePeriod: 3 * time.Second,
}

if err := res.heartbeat(); err != nil {
Expand Down
59 changes: 59 additions & 0 deletions providers/providers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// Copyright (c) Mondoo, Inc.
// SPDX-License-Identifier: BUSL-1.1

package providers

import (
"testing"
"time"

"github.com/stretchr/testify/require"
"go.mondoo.com/cnquery/v9/providers-sdk/v1/plugin"
)

type testPlugin struct {
plugin.Service
}

func (t *testPlugin) Connect(req *plugin.ConnectReq, callback plugin.ProviderCallback) (*plugin.ConnectRes, error) {
return nil, nil
}

func (t *testPlugin) MockConnect(req *plugin.ConnectReq, callback plugin.ProviderCallback) (*plugin.ConnectRes, error) {
return nil, nil
}

func (t *testPlugin) ParseCLI(req *plugin.ParseCLIReq) (*plugin.ParseCLIRes, error) {
return nil, nil
}

func (t *testPlugin) Shutdown(req *plugin.ShutdownReq) (*plugin.ShutdownRes, error) {
// sleep more than the heartbeat interval to ensure that even if shutting down
// the provider can still respond to heartbeats
time.Sleep(10 * time.Second)
return &plugin.ShutdownRes{}, nil
}

func (t *testPlugin) GetData(req *plugin.DataReq) (*plugin.DataRes, error) {
return nil, nil
}

func (t *testPlugin) StoreData(req *plugin.StoreReq) (*plugin.StoreRes, error) {
return nil, nil
}

func TestProviderShutdown(t *testing.T) {
s := &RunningProvider{
Plugin: &testPlugin{},
interval: 500 * time.Millisecond,
gracePeriod: 500 * time.Millisecond,
}
err := s.heartbeat()
require.NoError(t, err)
require.False(t, s.isCloseOrShutdown())
// the shutdown here takes 10 seconds, whereas the heartbeat interval is every second.
// this means that this provider gets multiple heartbeats while shutting down
err = s.Shutdown()
require.NoError(t, err)
require.True(t, s.isCloseOrShutdown())
}

0 comments on commit 8266a63

Please sign in to comment.