Skip to content

Commit

Permalink
use globally unique connection IDs for providers
Browse files Browse the repository at this point in the history
Signed-off-by: Ivan Milchev <[email protected]>
  • Loading branch information
imilchev committed Mar 5, 2024
1 parent 6cb42a3 commit ee0903f
Show file tree
Hide file tree
Showing 26 changed files with 89 additions and 24 deletions.
2 changes: 1 addition & 1 deletion providers-sdk/v1/inventory/inventory.go
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ func (cfg *Config) Clone(opts ...CloneOption) *Config {
}

clonedObject := proto.Clone(cfg).(*Config)

clonedObject.Id = 0
if cloneSettings.noDiscovery {
clonedObject.Discover = &Discovery{}
}
Expand Down
44 changes: 43 additions & 1 deletion providers-sdk/v1/plugin/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"time"

llx "go.mondoo.com/cnquery/v10/llx"
inventory "go.mondoo.com/cnquery/v10/providers-sdk/v1/inventory"
)

type Service struct {
Expand All @@ -31,7 +32,46 @@ func NewService() *Service {

var heartbeatRes HeartbeatRes

func (s *Service) AddRuntime(createRuntime func(connId uint32) (*Runtime, error)) (*Runtime, error) {
// FIXME: once we move to v12, remove the conf parametrer and remove the connId from the createRuntime function.
// The connection ID will always be set before the connection call is done, so we don't need to do anything about it here.
// The parameters are needed now, only to make sure that old clients can work with new providers.
func (s *Service) AddRuntime(conf *inventory.Config, createRuntime func(connId uint32) (*Runtime, error)) (*Runtime, error) {
// FIXME: DEPRECATED, remove in v12.0 vv
// This approach is used only when old clients use new providers. We will throw it away in v12
if conf.Id == 0 {
return s.deprecatedAddRuntime(createRuntime)
}
// ^^

s.runtimesLock.Lock()
defer s.runtimesLock.Unlock()

// If a runtime with this ID already exists, then return that
if runtime, ok := s.runtimes[conf.Id]; ok {
return runtime, nil
}

runtime, err := createRuntime(conf.Id)
if err != nil {
return nil, err
}

if runtime.Connection != nil {
if parentId := runtime.Connection.ParentID(); parentId > 0 {
parentRuntime, err := s.doGetRuntime(parentId)
if err != nil {
return nil, errors.New("parent connection " + strconv.FormatUint(uint64(parentId), 10) + " not found")
}
runtime.Resources = parentRuntime.Resources

}
}
s.runtimes[conf.Id] = runtime
return runtime, nil
}

// FIXME: DEPRECATED, remove in v12.0 vv
func (s *Service) deprecatedAddRuntime(createRuntime func(connId uint32) (*Runtime, error)) (*Runtime, error) {
s.runtimesLock.Lock()
defer s.runtimesLock.Unlock()

Expand All @@ -57,6 +97,8 @@ func (s *Service) AddRuntime(createRuntime func(connId uint32) (*Runtime, error)
return runtime, nil
}

// ^^

func (s *Service) GetRuntime(id uint32) (*Runtime, error) {
s.runtimesLock.Lock()
defer s.runtimesLock.Unlock()
Expand Down
2 changes: 1 addition & 1 deletion providers/arista/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ func (s *Service) connect(req *plugin.ConnectReq, callback plugin.ProviderCallba
asset := req.Asset
conf := asset.Connections[0]

runtime, err := s.AddRuntime(func(connId uint32) (*plugin.Runtime, error) {
runtime, err := s.AddRuntime(conf, func(connId uint32) (*plugin.Runtime, error) {
conn, err := connection.NewAristaConnection(connId, asset, conf)
if err != nil {
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion providers/atlassian/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ func (s *Service) connect(req *plugin.ConnectReq, callback plugin.ProviderCallba

asset := req.Asset
conf := asset.Connections[0]
runtime, err := s.AddRuntime(func(connId uint32) (*plugin.Runtime, error) {
runtime, err := s.AddRuntime(conf, func(connId uint32) (*plugin.Runtime, error) {
conn, err := connection.NewConnection(connId, asset, conf)
if err != nil {
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion providers/aws/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ func (s *Service) connect(req *plugin.ConnectReq, callback plugin.ProviderCallba
asset := req.Asset
conf := asset.Connections[0]

runtime, err := s.AddRuntime(func(connId uint32) (*plugin.Runtime, error) {
runtime, err := s.AddRuntime(conf, func(connId uint32) (*plugin.Runtime, error) {
var conn shared.Connection
var err error

Expand Down
2 changes: 1 addition & 1 deletion providers/azure/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ func (s *Service) connect(req *plugin.ConnectReq, callback plugin.ProviderCallba
asset := req.Asset
conf := asset.Connections[0]

runtime, err := s.AddRuntime(func(connId uint32) (*plugin.Runtime, error) {
runtime, err := s.AddRuntime(conf, func(connId uint32) (*plugin.Runtime, error) {
var conn shared.AzureConnection
var err error

Expand Down
18 changes: 18 additions & 0 deletions providers/coordinator.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
//go:generate mockgen -source=../providers-sdk/v1/resources/schema.go -destination=./mock_schema.go -package=providers

type ProvidersCoordinator interface {
NextConnectionId() uint32
NewRuntime() *Runtime
NewRuntimeFrom(parent *Runtime) *Runtime
RuntimeFor(asset *inventory.Asset, parent *Runtime) (*Runtime, error)
Expand Down Expand Up @@ -57,6 +58,9 @@ func newCoordinator() *coordinator {
}

type coordinator struct {
lastConnectionID uint32
connectionsLock sync.Mutex

providers Providers
runningByID map[string]*RunningProvider

Expand Down Expand Up @@ -88,6 +92,13 @@ type ProviderVersion struct {
Version string `json:"version"`
}

func (c *coordinator) NextConnectionId() uint32 {
c.connectionsLock.Lock()
defer c.connectionsLock.Unlock()
c.lastConnectionID++
return c.lastConnectionID
}

func (c *coordinator) tryProviderUpdate(provider *Provider, update UpdateProvidersConfig) (*Provider, error) {
if provider.Path == "" {
return nil, errors.New("cannot determine installation path for provider")
Expand Down Expand Up @@ -269,6 +280,13 @@ func (c *coordinator) RemoveRuntime(runtime *Runtime) {
}
}
}

// If all providers have been killed, reset the connection IDs back to 0
if len(c.runningByID) == 0 {
c.connectionsLock.Lock()
defer c.connectionsLock.Unlock()
c.lastConnectionID = 0
}
}

func (c *coordinator) GetRunningProvider(id string, update UpdateProvidersConfig) (*RunningProvider, error) {
Expand Down
2 changes: 1 addition & 1 deletion providers/core/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func (s *Service) Connect(req *plugin.ConnectReq, callback plugin.ProviderCallba
}

connectionId := defaultConnection
runtime, err := s.AddRuntime(func(connId uint32) (*plugin.Runtime, error) {
runtime, err := s.AddRuntime(req.Asset.Connections[0], func(connId uint32) (*plugin.Runtime, error) {
connectionId = connId
var upstream *upstream.UpstreamClient
var err error
Expand Down
2 changes: 1 addition & 1 deletion providers/equinix/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func (s *Service) connect(req *plugin.ConnectReq, callback plugin.ProviderCallba

asset := req.Asset
conf := asset.Connections[0]
runtime, err := s.AddRuntime(func(connId uint32) (*plugin.Runtime, error) {
runtime, err := s.AddRuntime(conf, func(connId uint32) (*plugin.Runtime, error) {
conn, err := connection.NewEquinixConnection(connId, asset, conf)
if err != nil {
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion providers/gcp/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ func (s *Service) connect(req *plugin.ConnectReq, callback plugin.ProviderCallba
asset := req.Asset
conf := asset.Connections[0]

runtime, err := s.AddRuntime(func(connId uint32) (*plugin.Runtime, error) {
runtime, err := s.AddRuntime(conf, func(connId uint32) (*plugin.Runtime, error) {
var conn shared.GcpConnection
var err error

Expand Down
2 changes: 1 addition & 1 deletion providers/github/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ func (s *Service) connect(req *plugin.ConnectReq, callback plugin.ProviderCallba
}

asset := req.Asset
runtime, err := s.AddRuntime(func(connId uint32) (*plugin.Runtime, error) {
runtime, err := s.AddRuntime(asset.Connections[0], func(connId uint32) (*plugin.Runtime, error) {
conn, err := connection.NewGithubConnection(connId, asset)
if err != nil {
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion providers/gitlab/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ func (s *Service) connect(req *plugin.ConnectReq, callback plugin.ProviderCallba

asset := req.Asset
conf := asset.Connections[0]
runtime, err := s.AddRuntime(func(connId uint32) (*plugin.Runtime, error) {
runtime, err := s.AddRuntime(conf, func(connId uint32) (*plugin.Runtime, error) {
conn, err := connection.NewGitLabConnection(connId, asset, conf)
if err != nil {
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion providers/google-workspace/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ func (s *Service) connect(req *plugin.ConnectReq, callback plugin.ProviderCallba

asset := req.Asset
conf := asset.Connections[0]
runtime, err := s.AddRuntime(func(connId uint32) (*plugin.Runtime, error) {
runtime, err := s.AddRuntime(conf, func(connId uint32) (*plugin.Runtime, error) {
conn, err := connection.NewGoogleWorkspaceConnection(connId, asset, conf)
if err != nil {
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion providers/ipmi/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ func (s *Service) connect(req *plugin.ConnectReq, callback plugin.ProviderCallba

asset := req.Asset
conf := asset.Connections[0]
runtime, err := s.AddRuntime(func(connId uint32) (*plugin.Runtime, error) {
runtime, err := s.AddRuntime(conf, func(connId uint32) (*plugin.Runtime, error) {
conn, err := connection.NewIpmiConnection(connId, asset, conf)
if err != nil {
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion providers/k8s/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ func (s *Service) connect(req *plugin.ConnectReq, callback plugin.ProviderCallba
asset := req.Asset
conf := asset.Connections[0]

runtime, err := s.AddRuntime(func(connId uint32) (*plugin.Runtime, error) {
runtime, err := s.AddRuntime(conf, func(connId uint32) (*plugin.Runtime, error) {
var conn shared.Connection
var err error
if manifestContent, ok := conf.Options[shared.OPTION_IMMEMORY_CONTENT]; ok {
Expand Down
2 changes: 1 addition & 1 deletion providers/ms365/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func (s *Service) connect(req *plugin.ConnectReq, callback plugin.ProviderCallba

asset := req.Asset
conf := asset.Connections[0]
runtime, err := s.AddRuntime(func(connId uint32) (*plugin.Runtime, error) {
runtime, err := s.AddRuntime(conf, func(connId uint32) (*plugin.Runtime, error) {
conn, err := connection.NewMs365Connection(connId, asset, conf)
if err != nil {
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion providers/network/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ func (s *Service) connect(req *plugin.ConnectReq, callback plugin.ProviderCallba
asset := req.Asset
conf := asset.Connections[0]

runtime, err := s.AddRuntime(func(connId uint32) (*plugin.Runtime, error) {
runtime, err := s.AddRuntime(conf, func(connId uint32) (*plugin.Runtime, error) {
var conn *connection.HostConnection

switch conf.Type {
Expand Down
2 changes: 1 addition & 1 deletion providers/oci/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ func (s *Service) connect(req *plugin.ConnectReq, callback plugin.ProviderCallba
asset := req.Asset
conf := asset.Connections[0]

runtime, err := s.AddRuntime(func(connId uint32) (*plugin.Runtime, error) {
runtime, err := s.AddRuntime(conf, func(connId uint32) (*plugin.Runtime, error) {
conn, err := connection.NewOciConnection(connId, asset, conf)
if err != nil {
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion providers/okta/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ func (s *Service) connect(req *plugin.ConnectReq, callback plugin.ProviderCallba

asset := req.Asset
conf := asset.Connections[0]
runtime, err := s.AddRuntime(func(connId uint32) (*plugin.Runtime, error) {
runtime, err := s.AddRuntime(conf, func(connId uint32) (*plugin.Runtime, error) {
conn, err := connection.NewOktaConnection(connId, asset, conf)
if err != nil {
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion providers/opcua/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func (s *Service) connect(req *plugin.ConnectReq, callback plugin.ProviderCallba

asset := req.Asset
conf := asset.Connections[0]
runtime, err := s.AddRuntime(func(connId uint32) (*plugin.Runtime, error) {
runtime, err := s.AddRuntime(conf, func(connId uint32) (*plugin.Runtime, error) {
conn, err := connection.NewOpcuaConnection(connId, asset, conf)
if err != nil {
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion providers/os/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ func (s *Service) connect(req *plugin.ConnectReq, callback plugin.ProviderCallba
asset := req.Asset
conf := asset.Connections[0]

runtime, err := s.AddRuntime(func(connId uint32) (*plugin.Runtime, error) {
runtime, err := s.AddRuntime(conf, func(connId uint32) (*plugin.Runtime, error) {
var conn shared.Connection
var err error

Expand Down
5 changes: 5 additions & 0 deletions providers/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,11 @@ func (r *Runtime) Connect(req *plugin.ConnectReq) error {
return errors.New("cannot connect to asset, no connection info provided")
}

// If there is no connection ID set, we need to assign one from the coordinator
if asset.Connections[0].Id == 0 {
asset.Connections[0].Id = Coordinator.NextConnectionId()
}

r.features = req.Features
callbacks := providerCallbacks{
runtime: r,
Expand Down
2 changes: 1 addition & 1 deletion providers/slack/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ func (s *Service) connect(req *plugin.ConnectReq, callback plugin.ProviderCallba

asset := req.Asset
conf := asset.Connections[0]
runtime, err := s.AddRuntime(func(connId uint32) (*plugin.Runtime, error) {
runtime, err := s.AddRuntime(conf, func(connId uint32) (*plugin.Runtime, error) {
var conn *connection.SlackConnection
var err error

Expand Down
2 changes: 1 addition & 1 deletion providers/terraform/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ func (s *Service) connect(req *plugin.ConnectReq, callback plugin.ProviderCallba
asset := req.Asset
conf := asset.Connections[0]

runtime, err := s.AddRuntime(func(connId uint32) (*plugin.Runtime, error) {
runtime, err := s.AddRuntime(conf, func(connId uint32) (*plugin.Runtime, error) {
var conn *connection.Connection
var err error

Expand Down
2 changes: 1 addition & 1 deletion providers/vcd/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ func (s *Service) connect(req *plugin.ConnectReq, callback plugin.ProviderCallba
asset := req.Asset
conf := asset.Connections[0]

runtime, err := s.AddRuntime(func(connId uint32) (*plugin.Runtime, error) {
runtime, err := s.AddRuntime(conf, func(connId uint32) (*plugin.Runtime, error) {
conn, err := connection.NewVcdConnection(connId, asset, conf)
if err != nil {
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion providers/vsphere/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ func (s *Service) connect(req *plugin.ConnectReq, callback plugin.ProviderCallba
asset := req.Asset
conf := asset.Connections[0]

runtime, err := s.AddRuntime(func(connId uint32) (*plugin.Runtime, error) {
runtime, err := s.AddRuntime(conf, func(connId uint32) (*plugin.Runtime, error) {
conn, err := connection.NewVsphereConnection(connId, asset, conf)
if err != nil {
return nil, err
Expand Down

0 comments on commit ee0903f

Please sign in to comment.