Skip to content

Commit

Permalink
🧹 Reduce the number of times we try to get platform ident
Browse files Browse the repository at this point in the history
 We were getting platform identity information multiple times. We would
 connect, then see the asset's platform not set, and redo it.
  • Loading branch information
jaym committed Apr 2, 2024
1 parent 48f7f43 commit 5b90e4b
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 25 deletions.
11 changes: 2 additions & 9 deletions providers-sdk/v1/sysinfo/sysinfo.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,12 @@
package sysinfo

import (
"errors"

"github.com/rs/zerolog/log"

"go.mondoo.com/cnquery/v10"
"go.mondoo.com/cnquery/v10/cli/execruntime"
"go.mondoo.com/cnquery/v10/providers-sdk/v1/inventory"
"go.mondoo.com/cnquery/v10/providers/os/connection/local"
"go.mondoo.com/cnquery/v10/providers/os/detector"
"go.mondoo.com/cnquery/v10/providers/os/id"
"go.mondoo.com/cnquery/v10/providers/os/id/hostname"
"go.mondoo.com/cnquery/v10/providers/os/resources/networkinterface"
Expand Down Expand Up @@ -47,18 +44,14 @@ func Get() (*SystemInfo, error) {
Type: "local",
}, &asset)

fingerprint, err := id.IdentifyPlatform(conn, asset.Platform, asset.IdDetector)
fingerprint, platform, err := id.IdentifyPlatform(conn, asset.Platform, asset.IdDetector)
if err == nil {
if len(fingerprint.PlatformIDs) > 0 {
sysInfo.PlatformId = fingerprint.PlatformIDs[0]
}
}

var ok bool
sysInfo.Platform, ok = detector.DetectOS(conn)
if !ok {
return nil, errors.New("failed to detect the OS")
}
sysInfo.Platform = platform

sysInfo.Hostname, _ = hostname.Hostname(conn, sysInfo.Platform)

Expand Down
8 changes: 4 additions & 4 deletions providers/os/id/platform.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@ type PlatformInfo struct {
RelatedPlatformIDs []string
}

func IdentifyPlatform(conn shared.Connection, p *inventory.Platform, idDetectors []string) (*PlatformFingerprint, error) {
func IdentifyPlatform(conn shared.Connection, p *inventory.Platform, idDetectors []string) (*PlatformFingerprint, *inventory.Platform, error) {
var ok bool
if p == nil {
p, ok = detector.DetectOS(conn)
if !ok {
return nil, errors.New("cannot detect os")
return nil, nil, errors.New("cannot detect os")
}
}

Expand Down Expand Up @@ -99,7 +99,7 @@ func IdentifyPlatform(conn shared.Connection, p *inventory.Platform, idDetectors

// if we found zero platform ids something went wrong
if len(platformIds) == 0 {
return nil, errors.New("could not determine a platform identifier")
return nil, nil, errors.New("could not determine a platform identifier")
}

fingerprint.PlatformIDs = platformIds
Expand All @@ -111,7 +111,7 @@ func IdentifyPlatform(conn shared.Connection, p *inventory.Platform, idDetectors
}

log.Debug().Interface("id-detector", idDetectors).Strs("platform-ids", platformIds).Msg("detected platform ids")
return &fingerprint, nil
return &fingerprint, p, nil
}

func GatherNameForPlatformId(id string) string {
Expand Down
23 changes: 23 additions & 0 deletions providers/os/provider/detector.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"go.mondoo.com/cnquery/v10/providers-sdk/v1/inventory"
"go.mondoo.com/cnquery/v10/providers/os/connection/shared"
"go.mondoo.com/cnquery/v10/providers/os/detector"
"go.mondoo.com/cnquery/v10/providers/os/id"
"go.mondoo.com/cnquery/v10/providers/os/id/aws"
"go.mondoo.com/cnquery/v10/providers/os/id/azure"
"go.mondoo.com/cnquery/v10/providers/os/id/gcp"
Expand Down Expand Up @@ -131,3 +132,25 @@ func relatedIds2assets(ids []string) []*inventory.Asset {
}
return res
}

func appendRelatedAssetsFromFingerprint(f *id.PlatformFingerprint, a *inventory.Asset) {
if f == nil || len(f.RelatedAssets) == 0 {
return
}
included := make(map[string]struct{}, len(a.RelatedAssets))
for i := range a.RelatedAssets {
included[a.RelatedAssets[i].Id] = struct{}{}
}
for _, ra := range f.RelatedAssets {
shouldAdd := true
for _, pId := range ra.PlatformIDs {
if _, ok := included[pId]; ok {
shouldAdd = false
break
}
}
if shouldAdd {
a.RelatedAssets = append(a.RelatedAssets, &inventory.Asset{Id: ra.PlatformIDs[0]})
}
}
}
23 changes: 17 additions & 6 deletions providers/os/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,11 +301,13 @@ func (s *Service) connect(req *plugin.ConnectReq, callback plugin.ProviderCallba
case shared.Type_Local.String():
conn = local.NewConnection(connId, conf, asset)

fingerprint, err := id.IdentifyPlatform(conn, asset.Platform, asset.IdDetector)
fingerprint, p, err := id.IdentifyPlatform(conn, asset.Platform, asset.IdDetector)
if err == nil {
asset.Name = fingerprint.Name
asset.PlatformIds = fingerprint.PlatformIDs
asset.IdDetector = fingerprint.ActiveIdDetectors
asset.Platform = p
appendRelatedAssetsFromFingerprint(fingerprint, asset)
}

case shared.Type_SSH.String():
Expand All @@ -314,13 +316,15 @@ func (s *Service) connect(req *plugin.ConnectReq, callback plugin.ProviderCallba
return nil, err
}

fingerprint, err := id.IdentifyPlatform(conn, asset.Platform, asset.IdDetector)
fingerprint, p, err := id.IdentifyPlatform(conn, asset.Platform, asset.IdDetector)
if err == nil {
if conn.Asset().Connections[0].Runtime != "vagrant" {
asset.Name = fingerprint.Name
}
asset.PlatformIds = fingerprint.PlatformIDs
asset.IdDetector = fingerprint.ActiveIdDetectors
asset.Platform = p
appendRelatedAssetsFromFingerprint(fingerprint, asset)
}

case shared.Type_Winrm.String():
Expand All @@ -329,11 +333,13 @@ func (s *Service) connect(req *plugin.ConnectReq, callback plugin.ProviderCallba
return nil, err
}

fingerprint, err := id.IdentifyPlatform(conn, asset.Platform, asset.IdDetector)
fingerprint, p, err := id.IdentifyPlatform(conn, asset.Platform, asset.IdDetector)
if err == nil {
asset.Name = fingerprint.Name
asset.PlatformIds = fingerprint.PlatformIDs
asset.IdDetector = fingerprint.ActiveIdDetectors
asset.Platform = p
appendRelatedAssetsFromFingerprint(fingerprint, asset)
}

case shared.Type_Tar.String():
Expand All @@ -342,11 +348,13 @@ func (s *Service) connect(req *plugin.ConnectReq, callback plugin.ProviderCallba
return nil, err
}

fingerprint, err := id.IdentifyPlatform(conn, asset.Platform, asset.IdDetector)
fingerprint, p, err := id.IdentifyPlatform(conn, asset.Platform, asset.IdDetector)
if err == nil {
asset.Name = fingerprint.Name
asset.PlatformIds = fingerprint.PlatformIDs
asset.IdDetector = fingerprint.ActiveIdDetectors
asset.Platform = p
appendRelatedAssetsFromFingerprint(fingerprint, asset)
}

case shared.Type_DockerSnapshot.String():
Expand All @@ -355,11 +363,13 @@ func (s *Service) connect(req *plugin.ConnectReq, callback plugin.ProviderCallba
return nil, err
}

fingerprint, err := id.IdentifyPlatform(conn, asset.Platform, asset.IdDetector)
fingerprint, p, err := id.IdentifyPlatform(conn, asset.Platform, asset.IdDetector)
if err == nil {
asset.Name = fingerprint.Name
asset.PlatformIds = fingerprint.PlatformIDs
asset.IdDetector = fingerprint.ActiveIdDetectors
asset.Platform = p
appendRelatedAssetsFromFingerprint(fingerprint, asset)
}

case shared.Type_Vagrant.String():
Expand Down Expand Up @@ -394,11 +404,12 @@ func (s *Service) connect(req *plugin.ConnectReq, callback plugin.ProviderCallba
// This is a workaround to set Google COS platform IDs when scanned from inside k8s
pID, err := conn.(*fs.FileSystemConnection).Identifier()
if err != nil {
fingerprint, err := id.IdentifyPlatform(conn, asset.Platform, asset.IdDetector)
fingerprint, p, err := id.IdentifyPlatform(conn, asset.Platform, asset.IdDetector)
if err == nil {
asset.Name = fingerprint.Name
asset.PlatformIds = fingerprint.PlatformIDs
asset.IdDetector = fingerprint.ActiveIdDetectors
asset.Platform = p
}
} else {
// In this case asset.Name should already be set via the inventory
Expand Down
10 changes: 4 additions & 6 deletions providers/os/provider/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,8 @@ func TestLocalConnectionIdDetectors(t *testing.T) {
require.Contains(t, connectResp.Asset.IdDetector, ids.IdDetector_Hostname)
require.Contains(t, connectResp.Asset.IdDetector, ids.IdDetector_CloudDetect)
require.NotContains(t, connectResp.Asset.IdDetector, ids.IdDetector_SshHostkey)
// here we have the hostname twice, as platformid and stand alone
// This get's cleaned up later in the code
// FIXME: this should only be 1
require.Len(t, connectResp.Asset.PlatformIds, 2)

require.Len(t, connectResp.Asset.PlatformIds, 1)

shutdownconnectResp, err := srv.Shutdown(&plugin.ShutdownReq{})
require.NoError(t, err)
Expand Down Expand Up @@ -106,8 +104,8 @@ func TestLocalConnectionIdDetectors_DelayedDiscovery(t *testing.T) {
require.Contains(t, connectResp.Asset.IdDetector, ids.IdDetector_Hostname)
require.Contains(t, connectResp.Asset.IdDetector, ids.IdDetector_CloudDetect)
require.NotContains(t, connectResp.Asset.IdDetector, ids.IdDetector_SshHostkey)
// Now the platformIDs are cleaned up
require.Len(t, connectResp.Asset.PlatformIds, 2)

require.Len(t, connectResp.Asset.PlatformIds, 1)
// Verify the platform is set
require.NotNil(t, connectResp.Asset.Platform)

Expand Down

0 comments on commit 5b90e4b

Please sign in to comment.