diff --git a/apps/cnquery/cmd/plugin.go b/apps/cnquery/cmd/plugin.go index 3711f398e7..5bfdb3ac06 100644 --- a/apps/cnquery/cmd/plugin.go +++ b/apps/cnquery/cmd/plugin.go @@ -91,15 +91,15 @@ func (c *cnqueryPlugin) RunQuery(conf *run.RunQueryConfig, runtime *providers.Ru assetList := runtime.Provider.Connection.Inventory.Spec.Assets log.Debug().Msgf("resolved %d assets", len(assetList)) - filteredAssets := []*inventory.Asset{} + assetCandidates := []*inventory.Asset{} if len(assetList) > 1 && conf.PlatformId != "" { filteredAsset, err := filterAssetByPlatformID(assetList, conf.PlatformId) if err != nil { return err } - filteredAssets = append(filteredAssets, filteredAsset) + assetCandidates = append(assetCandidates, filteredAsset) } else { - filteredAssets = assetList + assetCandidates = assetList } if conf.Format == "json" { @@ -117,50 +117,13 @@ func (c *cnqueryPlugin) RunQuery(conf *run.RunQueryConfig, runtime *providers.Ru } } - for _, asset := range filteredAssets { - // If the assets have platform IDs, then we have already connected to them via the - // current provider. - if len(asset.PlatformIds) > 0 { - continue - } - - // Make sure the provider for the asset is present - if err := runtime.DetectProvider(asset); err != nil { - return err - } - - err := runtime.Connect(&pp.ConnectReq{ - Features: config.Features, - Asset: asset, - Upstream: upstreamConfig, - }) - if err != nil { - return err - } - } - - // TODO: filter unique assets by platform ID - uniqueAssets := []*inventory.Asset{} - platformIds := map[string]struct{}{} - for _, asset := range filteredAssets { - found := false - for _, platformId := range asset.PlatformIds { - if _, ok := platformIds[platformId]; ok { - found = true - } - } - if found { - continue - } - - uniqueAssets = append(uniqueAssets, asset) - for _, platformId := range asset.PlatformIds { - platformIds[platformId] = struct{}{} - } + uniqueAssets, err := providers.ProcessAssetCandidates(runtime, assetCandidates, upstreamConfig) + if err != nil { + return err } for i := range uniqueAssets { - connectAsset := filteredAssets[i] + connectAsset := uniqueAssets[i] if err := runtime.DetectProvider(connectAsset); err != nil { return err } @@ -203,7 +166,7 @@ func (c *cnqueryPlugin) RunQuery(conf *run.RunQueryConfig, runtime *providers.Ru sh.PrintResults(code, results) } else { reporter.BundleResultsToJSON(code, results, out) - if len(filteredAssets) != i+1 { + if len(uniqueAssets) != i+1 { out.WriteString(",") } } diff --git a/apps/cnquery/cmd/shell.go b/apps/cnquery/cmd/shell.go index 9c3891f273..0f7cec97e7 100644 --- a/apps/cnquery/cmd/shell.go +++ b/apps/cnquery/cmd/shell.go @@ -99,8 +99,14 @@ func StartShell(runtime *providers.Runtime, conf *ShellConfig) error { log.Fatal().Err(err).Msg("could not load asset information") } - assetList := res.Inventory.Spec.Assets - log.Debug().Msgf("resolved %d assets", len(assetList)) + assetCandidates := res.Inventory.Spec.Assets + log.Debug().Msgf("resolved %d assets", len(assetCandidates)) + + assetList, err := providers.ProcessAssetCandidates(runtime, assetCandidates, conf.UpstreamConfig) + if err != nil { + log.Fatal().Err(err).Msg("could not load asset information") + } + log.Debug().Msgf("resolved %d unique assets", len(assetList)) if len(assetList) == 0 { log.Fatal().Msg("could not find an asset that we can connect to") diff --git a/providers/assets.go b/providers/assets.go new file mode 100644 index 0000000000..4a06599489 --- /dev/null +++ b/providers/assets.go @@ -0,0 +1,65 @@ +package providers + +import ( + "go.mondoo.com/cnquery/cli/config" + "go.mondoo.com/cnquery/providers-sdk/v1/inventory" + pp "go.mondoo.com/cnquery/providers-sdk/v1/plugin" + "go.mondoo.com/cnquery/providers-sdk/v1/upstream" +) + +func ProcessAssetCandidates(runtime *Runtime, assetCandidates []*inventory.Asset, upstreamConfig *upstream.UpstreamConfig) ([]*inventory.Asset, error) { + if err := detectAssets(runtime, assetCandidates, upstreamConfig); err != nil { + return nil, err + } + + return filterUniqueAssets(assetCandidates), nil +} + +// detectAssets connects to all assets that do not have a platform ID yet +func detectAssets(runtime *Runtime, assetCandidates []*inventory.Asset, upstreamConfig *upstream.UpstreamConfig) error { + for _, asset := range assetCandidates { + // If the assets have platform IDs, then we have already connected to them via the + // current provider. + if len(asset.PlatformIds) > 0 { + continue + } + + // Make sure the provider for the asset is present + if err := runtime.DetectProvider(asset); err != nil { + return err + } + + err := runtime.Connect(&pp.ConnectReq{ + Features: config.Features, + Asset: asset, + Upstream: upstreamConfig, + }) + if err != nil { + return err + } + } + return nil +} + +// filterUniqueAssets filters assets with duplicate platform IDs +func filterUniqueAssets(assetCandidates []*inventory.Asset) []*inventory.Asset { + uniqueAssets := []*inventory.Asset{} + platformIds := map[string]struct{}{} + for _, asset := range assetCandidates { + found := false + for _, platformId := range asset.PlatformIds { + if _, ok := platformIds[platformId]; ok { + found = true + } + } + if found { + continue + } + + uniqueAssets = append(uniqueAssets, asset) + for _, platformId := range asset.PlatformIds { + platformIds[platformId] = struct{}{} + } + } + return uniqueAssets +}