From 3391deb96a4ea7ec9009e85989cc4d1c867acb20 Mon Sep 17 00:00:00 2001 From: Ivan Milchev Date: Mon, 5 Feb 2024 22:25:22 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=A7=B9=20refactor=20local=20scanner=20cod?= =?UTF-8?q?e=20(#3202)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * refactor local scanner code Signed-off-by: Ivan Milchev * fix license error Signed-off-by: Ivan Milchev * expose more data structures Signed-off-by: Ivan Milchev * add tests for DiscoverAssets Signed-off-by: Ivan Milchev * refactor code for creating runtime for assets Signed-off-by: Ivan Milchev * test Batch Signed-off-by: Ivan Milchev * call SynchronizeAssets for the batch Signed-off-by: Ivan Milchev * fix error Signed-off-by: Ivan Milchev * add license header Signed-off-by: Ivan Milchev * add more tests Signed-off-by: Ivan Milchev * fix license check Signed-off-by: Ivan Milchev * make sure ci/cd test passes in github Signed-off-by: Ivan Milchev * fix comments about tests Signed-off-by: Ivan Milchev --------- Signed-off-by: Ivan Milchev --- apps/cnquery/cmd/shell.go | 16 +- explorer/scan/discovery.go | 185 ++++++++ explorer/scan/discovery_test.go | 224 ++++++++++ explorer/scan/local_scanner.go | 421 +++++++----------- explorer/scan/reporter.go | 15 +- explorer/scan/testdata/2pods.yaml | 32 ++ .../scan/testdata/3pods_with_duplicate.yaml | 48 ++ providers/assets.go | 110 ----- utils/slicesx/batch.go | 16 + utils/slicesx/batch_test.go | 43 ++ 10 files changed, 712 insertions(+), 398 deletions(-) create mode 100644 explorer/scan/discovery.go create mode 100644 explorer/scan/discovery_test.go create mode 100644 explorer/scan/testdata/2pods.yaml create mode 100644 explorer/scan/testdata/3pods_with_duplicate.yaml delete mode 100644 providers/assets.go create mode 100644 utils/slicesx/batch.go create mode 100644 utils/slicesx/batch_test.go diff --git a/apps/cnquery/cmd/shell.go b/apps/cnquery/cmd/shell.go index ce1a840f1c..6db9ee4a10 100644 --- a/apps/cnquery/cmd/shell.go +++ b/apps/cnquery/cmd/shell.go @@ -4,6 +4,7 @@ package cmd import ( + "context" "errors" "fmt" "os" @@ -17,6 +18,7 @@ import ( "go.mondoo.com/cnquery/v10/cli/config" "go.mondoo.com/cnquery/v10/cli/shell" "go.mondoo.com/cnquery/v10/cli/theme" + "go.mondoo.com/cnquery/v10/explorer/scan" "go.mondoo.com/cnquery/v10/providers" "go.mondoo.com/cnquery/v10/providers-sdk/v1/inventory" "go.mondoo.com/cnquery/v10/providers-sdk/v1/inventory/manager" @@ -113,21 +115,23 @@ func StartShell(runtime *providers.Runtime, conf *ShellConfig) error { log.Fatal().Err(err).Msg("could not load asset information") } - assets, err := providers.ProcessAssetCandidates(runtime, res, conf.UpstreamConfig, conf.PlatformID) + ctx := context.Background() + discoveredAssets, err := scan.DiscoverAssets(ctx, res.Inventory, conf.UpstreamConfig, providers.NullRecording{}) if err != nil { log.Fatal().Err(err).Msg("could not process assets") } - if len(assets) == 0 { + filteredAssets := discoveredAssets.GetAssetsByPlatformID(conf.PlatformID) + if len(filteredAssets) == 0 { log.Fatal().Msg("could not find an asset that we can connect to") } - connectAsset := assets[0] - if len(assets) > 1 { + connectAsset := filteredAssets[0] + if len(filteredAssets) > 1 { isTTY := isatty.IsTerminal(os.Stdout.Fd()) if isTTY { - connectAsset = components.AssetSelect(assets) + connectAsset = components.AssetSelect(filteredAssets) } else { - fmt.Println(components.AssetList(theme.OperatingSystemTheme, assets)) + fmt.Println(components.AssetList(theme.OperatingSystemTheme, 
filteredAssets)) log.Fatal().Msg("cannot connect to more than one asset, use --platform-id to select a specific asset") } } diff --git a/explorer/scan/discovery.go b/explorer/scan/discovery.go new file mode 100644 index 0000000000..19eaf52573 --- /dev/null +++ b/explorer/scan/discovery.go @@ -0,0 +1,185 @@ +// Copyright (c) Mondoo, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package scan + +import ( + "context" + "errors" + + "github.com/rs/zerolog/log" + "go.mondoo.com/cnquery/v10/cli/config" + "go.mondoo.com/cnquery/v10/cli/execruntime" + "go.mondoo.com/cnquery/v10/llx" + "go.mondoo.com/cnquery/v10/providers" + inventory "go.mondoo.com/cnquery/v10/providers-sdk/v1/inventory" + "go.mondoo.com/cnquery/v10/providers-sdk/v1/inventory/manager" + "go.mondoo.com/cnquery/v10/providers-sdk/v1/plugin" + "go.mondoo.com/cnquery/v10/providers-sdk/v1/upstream" +) + +type AssetWithRuntime struct { + Asset *inventory.Asset + Runtime *providers.Runtime +} + +type AssetWithError struct { + Asset *inventory.Asset + Err error +} + +type DiscoveredAssets struct { + platformIds map[string]struct{} + Assets []*AssetWithRuntime + Errors []*AssetWithError +} + +// Add adds an asset and its runtime to the discovered assets list. It returns true if the +// asset has been added, false if it is a duplicate +func (d *DiscoveredAssets) Add(asset *inventory.Asset, runtime *providers.Runtime) bool { + isDuplicate := false + for _, platformId := range asset.PlatformIds { + if _, ok := d.platformIds[platformId]; ok { + isDuplicate = true + break + } + d.platformIds[platformId] = struct{}{} + } + if isDuplicate { + return false + } + + d.Assets = append(d.Assets, &AssetWithRuntime{Asset: asset, Runtime: runtime}) + return true +} + +func (d *DiscoveredAssets) AddError(asset *inventory.Asset, err error) { + d.Errors = append(d.Errors, &AssetWithError{Asset: asset, Err: err}) +} + +func (d *DiscoveredAssets) GetAssetsByPlatformID(platformID string) []*inventory.Asset { + var assets []*inventory.Asset + for _, a := range d.Assets { + for _, p := range a.Asset.PlatformIds { + if platformID == "" || p == platformID { + assets = append(assets, a.Asset) + break + } + } + } + return assets +} + +// DiscoverAssets discovers assets from the given inventory and upstream configuration. 
Returns only unique assets.
+func DiscoverAssets(ctx context.Context, inv *inventory.Inventory, upstream *upstream.UpstreamConfig, recording llx.Recording) (*DiscoveredAssets, error) {
+	im, err := manager.NewManager(manager.WithInventory(inv, providers.DefaultRuntime()))
+	if err != nil {
+		return nil, errors.New("failed to resolve inventory for connection")
+	}
+	invAssets := im.GetAssets()
+	if len(invAssets) == 0 {
+		return nil, errors.New("could not find an asset that we can connect to")
+	}
+
+	runtimeEnv := execruntime.Detect()
+	var runtimeLabels map[string]string
+	// If the runtime is an automated environment and the root asset is CI/CD, then we are doing a
+	// CI/CD scan and we need to apply the runtime labels to the assets
+	if runtimeEnv != nil &&
+		runtimeEnv.IsAutomatedEnv() &&
+		inv.Spec.Assets[0].Category == inventory.AssetCategory_CATEGORY_CICD {
+		runtimeLabels = runtimeEnv.Labels()
+	}
+
+	discoveredAssets := &DiscoveredAssets{platformIds: map[string]struct{}{}}
+
+	// we connect and perform discovery for each asset in the job inventory
+	for _, rootAsset := range invAssets {
+		resolvedRootAsset, err := im.ResolveAsset(rootAsset)
+		if err != nil {
+			return nil, err
+		}
+
+		// create a runtime for the root asset
+		rootAssetWithRuntime, err := createRuntimeForAsset(resolvedRootAsset, upstream, recording)
+		if err != nil {
+			log.Error().Err(err).Str("asset", resolvedRootAsset.Name).Msg("unable to create runtime for asset")
+			discoveredAssets.AddError(resolvedRootAsset, err)
+			continue
+		}
+
+		resolvedRootAsset = rootAssetWithRuntime.Asset // to ensure we get all the information the connect call gave us
+
+		// for all discovered assets, we apply mondoo-specific labels and annotations that come from the root asset
+		for _, a := range rootAssetWithRuntime.Runtime.Provider.Connection.Inventory.Spec.Assets {
+			// create a runtime for the discovered asset
+			assetWithRuntime, err := createRuntimeForAsset(a, upstream, recording)
+			if err != nil {
+				log.Error().Err(err).Str("asset", a.Name).Msg("unable to create runtime for asset")
+				discoveredAssets.AddError(a, err)
+				continue
+			}
+
+			resolvedAsset := assetWithRuntime.Runtime.Provider.Connection.Asset
+			prepareAsset(resolvedAsset, resolvedRootAsset, runtimeLabels)
+
+			// If the asset has already been added, we should close its runtime
+			if !discoveredAssets.Add(resolvedAsset, assetWithRuntime.Runtime) {
+				assetWithRuntime.Runtime.Close()
+			}
+		}
+	}
+
+	// if there is exactly one asset, assure that the --asset-name is used
+	// TODO: make it so that the --asset-name is set for the root asset only even if multiple assets are there
+	// This is a temporary fix that only works if there is only one asset
+	if len(discoveredAssets.Assets) == 1 && invAssets[0].Name != "" && invAssets[0].Name != discoveredAssets.Assets[0].Asset.Name {
+		log.Debug().Str("asset", discoveredAssets.Assets[0].Asset.Name).Msg("Overriding asset name with --asset-name flag")
+		discoveredAssets.Assets[0].Asset.Name = invAssets[0].Name
+	}
+
+	return discoveredAssets, nil
+}
+
+func createRuntimeForAsset(asset *inventory.Asset, upstream *upstream.UpstreamConfig, recording llx.Recording) (*AssetWithRuntime, error) {
+	var runtime *providers.Runtime
+	var err error
+	// Close the runtime if an error occurred
+	defer func() {
+		if err != nil && runtime != nil {
+			runtime.Close()
+		}
+	}()
+
+	runtime, err = providers.Coordinator.RuntimeFor(asset, providers.DefaultRuntime())
+	if err != nil {
+		return nil, err
+	}
+	if err = runtime.SetRecording(recording); 
err != nil { + return nil, err + } + + err = runtime.Connect(&plugin.ConnectReq{ + Features: config.Features, + Asset: asset, + Upstream: upstream, + }) + if err != nil { + return nil, err + } + return &AssetWithRuntime{Asset: runtime.Provider.Connection.Asset, Runtime: runtime}, nil +} + +// prepareAsset prepares the asset for further processing by adding mondoo-specific labels and annotations +func prepareAsset(a *inventory.Asset, rootAsset *inventory.Asset, runtimeLabels map[string]string) { + a.AddMondooLabels(rootAsset) + a.AddAnnotations(rootAsset.GetAnnotations()) + a.ManagedBy = rootAsset.ManagedBy + a.KindString = a.GetPlatform().Kind + for k, v := range runtimeLabels { + if a.Labels == nil { + a.Labels = map[string]string{} + } + a.Labels[k] = v + } +} diff --git a/explorer/scan/discovery_test.go b/explorer/scan/discovery_test.go new file mode 100644 index 0000000000..3c0ecb5128 --- /dev/null +++ b/explorer/scan/discovery_test.go @@ -0,0 +1,224 @@ +// Copyright (c) Mondoo, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package scan + +import ( + "context" + "fmt" + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.mondoo.com/cnquery/v10/providers" + inventory "go.mondoo.com/cnquery/v10/providers-sdk/v1/inventory" +) + +func TestDiscoveredAssets_Add(t *testing.T) { + d := &DiscoveredAssets{ + platformIds: map[string]struct{}{}, + Assets: []*AssetWithRuntime{}, + Errors: []*AssetWithError{}, + } + asset := &inventory.Asset{ + PlatformIds: []string{"platform1"}, + } + runtime := &providers.Runtime{} + + assert.True(t, d.Add(asset, runtime)) + assert.Len(t, d.Assets, 1) + assert.Len(t, d.Errors, 0) + + // Make sure adding duplicates is not possible + assert.False(t, d.Add(asset, runtime)) + assert.Len(t, d.Assets, 1) + assert.Len(t, d.Errors, 0) +} + +func TestDiscoveredAssets_Add_MultiplePlatformIDs(t *testing.T) { + d := &DiscoveredAssets{ + platformIds: map[string]struct{}{}, + Assets: []*AssetWithRuntime{}, + Errors: []*AssetWithError{}, + } + asset := &inventory.Asset{ + PlatformIds: []string{"platform1", "platform2"}, + } + runtime := &providers.Runtime{} + + assert.True(t, d.Add(asset, runtime)) + assert.Len(t, d.Assets, 1) + assert.Len(t, d.Errors, 0) + + // Make sure adding duplicates is not possible + assert.False(t, d.Add(&inventory.Asset{ + PlatformIds: []string{"platform3", asset.PlatformIds[0]}, + }, runtime)) + assert.Len(t, d.Assets, 1) + assert.Len(t, d.Errors, 0) +} + +func TestDiscoveredAssets_GetAssetsByPlatformID(t *testing.T) { + d := &DiscoveredAssets{ + platformIds: map[string]struct{}{}, + Assets: []*AssetWithRuntime{}, + Errors: []*AssetWithError{}, + } + + allPlatformIds := []string{} + for i := 0; i < 10; i++ { + pId := fmt.Sprintf("platform1%d", i) + allPlatformIds = append(allPlatformIds, pId) + asset := &inventory.Asset{ + PlatformIds: []string{pId}, + } + runtime := &providers.Runtime{} + + assert.True(t, d.Add(asset, runtime)) + } + assert.Len(t, d.Assets, 10) + + // Make sure adding duplicates is not possible + assets := d.GetAssetsByPlatformID(allPlatformIds[0]) + assert.Len(t, assets, 1) + assert.Equal(t, allPlatformIds[0], assets[0].PlatformIds[0]) +} + +func TestDiscoveredAssets_GetAssetsByPlatformID_Empty(t *testing.T) { + d := &DiscoveredAssets{ + platformIds: map[string]struct{}{}, + Assets: []*AssetWithRuntime{}, + Errors: []*AssetWithError{}, + } + + allPlatformIds := []string{} + for i := 0; i < 10; i++ { + pId := fmt.Sprintf("platform1%d", i) + allPlatformIds = append(allPlatformIds, 
pId) + asset := &inventory.Asset{ + PlatformIds: []string{pId}, + } + runtime := &providers.Runtime{} + + assert.True(t, d.Add(asset, runtime)) + } + assert.Len(t, d.Assets, 10) + + // Make sure adding duplicates is not possible + assets := d.GetAssetsByPlatformID("") + assert.Len(t, assets, 10) + platformIds := []string{} + for _, a := range assets { + platformIds = append(platformIds, a.PlatformIds[0]) + } + assert.ElementsMatch(t, allPlatformIds, platformIds) +} + +func TestDiscoverAssets(t *testing.T) { + getInventory := func() *inventory.Inventory { + return &inventory.Inventory{ + Spec: &inventory.InventorySpec{ + Assets: []*inventory.Asset{ + { + Connections: []*inventory.Config{ + { + Type: "k8s", + Options: map[string]string{ + "path": "./testdata/2pods.yaml", + }, + Discover: &inventory.Discovery{ + Targets: []string{"auto"}, + }, + }, + }, + ManagedBy: "mondoo-operator-123", + }, + }, + }, + } + } + + t.Run("normal", func(t *testing.T) { + inv := getInventory() + discoveredAssets, err := DiscoverAssets(context.Background(), inv, nil, providers.NullRecording{}) + require.NoError(t, err) + assert.Len(t, discoveredAssets.Assets, 3) + assert.Len(t, discoveredAssets.Errors, 0) + assert.Equal(t, "mondoo-operator-123", discoveredAssets.Assets[0].Asset.ManagedBy) + assert.Equal(t, "mondoo-operator-123", discoveredAssets.Assets[1].Asset.ManagedBy) + assert.Equal(t, "mondoo-operator-123", discoveredAssets.Assets[2].Asset.ManagedBy) + }) + + t.Run("with duplicate root assets", func(t *testing.T) { + inv := getInventory() + inv.Spec.Assets = append(inv.Spec.Assets, inv.Spec.Assets[0]) + discoveredAssets, err := DiscoverAssets(context.Background(), inv, nil, providers.NullRecording{}) + require.NoError(t, err) + + // Make sure no duplicates are returned + assert.Len(t, discoveredAssets.Assets, 3) + assert.Len(t, discoveredAssets.Errors, 0) + }) + + t.Run("with duplicate discovered assets", func(t *testing.T) { + inv := getInventory() + inv.Spec.Assets[0].Connections[0].Options["path"] = "./testdata/3pods_with_duplicate.yaml" + discoveredAssets, err := DiscoverAssets(context.Background(), inv, nil, providers.NullRecording{}) + require.NoError(t, err) + + // Make sure no duplicates are returned + assert.Len(t, discoveredAssets.Assets, 3) + assert.Len(t, discoveredAssets.Errors, 0) + }) + + t.Run("copy root asset annotations", func(t *testing.T) { + inv := getInventory() + inv.Spec.Assets[0].Annotations = map[string]string{ + "key1": "value1", + "key2": "value2", + } + discoveredAssets, err := DiscoverAssets(context.Background(), inv, nil, providers.NullRecording{}) + require.NoError(t, err) + + for _, asset := range discoveredAssets.Assets { + for k, v := range inv.Spec.Assets[0].Annotations { + require.Contains(t, asset.Asset.Annotations, k) + assert.Equal(t, v, asset.Asset.Annotations[k]) + } + } + }) + + t.Run("copy root asset managedBy", func(t *testing.T) { + inv := getInventory() + inv.Spec.Assets[0].ManagedBy = "managed-by-test" + discoveredAssets, err := DiscoverAssets(context.Background(), inv, nil, providers.NullRecording{}) + require.NoError(t, err) + + for _, asset := range discoveredAssets.Assets { + assert.Equal(t, inv.Spec.Assets[0].ManagedBy, asset.Asset.ManagedBy) + } + }) + + t.Run("set ci/cd labels", func(t *testing.T) { + inv := getInventory() + + val, isSet := os.LookupEnv("GITHUB_ACTION") + defer func() { + if isSet { + require.NoError(t, os.Setenv("GITHUB_ACTION", val)) + } else { + require.NoError(t, os.Unsetenv("GITHUB_ACTION")) + } + }() + inv.Spec.Assets[0].Category 
= inventory.AssetCategory_CATEGORY_CICD + require.NoError(t, os.Setenv("GITHUB_ACTION", "go-test")) + discoveredAssets, err := DiscoverAssets(context.Background(), inv, nil, providers.NullRecording{}) + require.NoError(t, err) + + for _, asset := range discoveredAssets.Assets { + require.Contains(t, asset.Asset.Labels, "mondoo.com/exec-environment") + assert.Equal(t, "actions.github.com", asset.Asset.Labels["mondoo.com/exec-environment"]) + } + }) +} diff --git a/explorer/scan/local_scanner.go b/explorer/scan/local_scanner.go index ab9ae90fd0..cc998ce0fc 100644 --- a/explorer/scan/local_scanner.go +++ b/explorer/scan/local_scanner.go @@ -12,15 +12,11 @@ import ( sync "sync" "time" - "go.mondoo.com/cnquery/v10/providers-sdk/v1/inventory/manager" - "go.mondoo.com/cnquery/v10/providers-sdk/v1/plugin" - "github.com/mattn/go-isatty" "github.com/rs/zerolog/log" "github.com/segmentio/ksuid" "github.com/spf13/viper" "go.mondoo.com/cnquery/v10" - "go.mondoo.com/cnquery/v10/cli/config" "go.mondoo.com/cnquery/v10/cli/progress" "go.mondoo.com/cnquery/v10/explorer" "go.mondoo.com/cnquery/v10/explorer/executor" @@ -33,21 +29,12 @@ import ( "go.mondoo.com/cnquery/v10/providers-sdk/v1/inventory" "go.mondoo.com/cnquery/v10/providers-sdk/v1/upstream" "go.mondoo.com/cnquery/v10/utils/multierr" + "go.mondoo.com/cnquery/v10/utils/slicesx" "go.mondoo.com/ranger-rpc/codes" "go.mondoo.com/ranger-rpc/status" "google.golang.org/protobuf/proto" ) -type assetWithRuntime struct { - asset *inventory.Asset - runtime *providers.Runtime -} - -type assetWithError struct { - asset *inventory.Asset - err error -} - type LocalScanner struct { fetcher *fetcher upstream *upstream.UpstreamConfig @@ -108,7 +95,7 @@ func (s *LocalScanner) Run(ctx context.Context, job *Job) (*explorer.ReportColle if err != nil { return nil, err } - reports, _, err := s.distributeJob(job, ctx, upstreamConfig) + reports, err := s.distributeJob(job, ctx, upstreamConfig) if err != nil { if code := status.Code(err); code == codes.Unauthenticated { return nil, multierr.Wrap(err, @@ -155,7 +142,7 @@ func (s *LocalScanner) RunIncognito(ctx context.Context, job *Job) (*explorer.Re // skip the error check, we are running in incognito upstreamConfig, _ := s.getUpstreamConfig(job.Inventory, true) - reports, _, err := s.distributeJob(job, ctx, upstreamConfig) + reports, err := s.distributeJob(job, ctx, upstreamConfig) if err != nil { return nil, err } @@ -184,295 +171,187 @@ func preprocessQueryPackFilters(filters []string) []string { return res } -func (s *LocalScanner) distributeJob(job *Job, ctx context.Context, upstream *upstream.UpstreamConfig) (*explorer.ReportCollection, bool, error) { +func CreateProgressBar(discoveredAssets *DiscoveredAssets, disableProgressBar bool) (progress.MultiProgress, error) { + var multiprogress progress.MultiProgress + if isatty.IsTerminal(os.Stdout.Fd()) && !disableProgressBar && !strings.EqualFold(logger.GetLevel(), "debug") && !strings.EqualFold(logger.GetLevel(), "trace") { + progressBarElements := map[string]string{} + orderedKeys := []string{} + for i := range discoveredAssets.Assets { + asset := discoveredAssets.Assets[i].Asset + // this shouldn't happen, but might + // it normally indicates a bug in the provider + if presentAsset, present := progressBarElements[asset.PlatformIds[0]]; present { + return nil, fmt.Errorf("asset %s and %s have the same platform id %s", presentAsset, asset.Name, asset.PlatformIds[0]) + } + progressBarElements[asset.PlatformIds[0]] = asset.Name + orderedKeys = append(orderedKeys, 
asset.PlatformIds[0]) + } + var err error + multiprogress, err = progress.NewMultiProgressBars(progressBarElements, orderedKeys, progress.WithScore()) + if err != nil { + return nil, multierr.Wrap(err, "failed to create progress bars") + } + } else { + // TODO: adjust naming + multiprogress = progress.NoopMultiProgressBars{} + } + return multiprogress, nil +} + +func (s *LocalScanner) distributeJob(job *Job, ctx context.Context, upstream *upstream.UpstreamConfig) (*explorer.ReportCollection, error) { log.Info().Msgf("discover related assets for %d asset(s)", len(job.Inventory.Spec.Assets)) // Always shut down the coordinator, to make sure providers are killed defer providers.Coordinator.Shutdown() - im, err := manager.NewManager(manager.WithInventory(job.Inventory, providers.DefaultRuntime())) + discoveredAssets, err := DiscoverAssets(ctx, job.Inventory, upstream, s.recording) if err != nil { - return nil, false, errors.New("failed to resolve inventory for connection") - } - assetList := im.GetAssets() - - var assets []*assetWithRuntime - // note: asset candidate runtimes are the runtime that discovered them - var assetCandidates []*assetWithRuntime - var assetErrors []*assetWithError - - // we connect and perform discovery for each asset in the job inventory - for i := range assetList { - asset := assetList[i] - resolvedAsset, err := im.ResolveAsset(asset) - if err != nil { - return nil, false, err - } - - runtime, err := providers.Coordinator.RuntimeFor(asset, providers.DefaultRuntime()) - if err != nil { - log.Error().Err(err).Str("asset", asset.Name).Msg("unable to create runtime for asset") - assetErrors = append(assetErrors, &assetWithError{ - asset: resolvedAsset, - err: err, - }) - continue - } - runtime.SetRecording(s.recording) - - if err := runtime.Connect(&plugin.ConnectReq{ - Features: cnquery.GetFeatures(ctx), - Asset: resolvedAsset, - Upstream: upstream, - }); err != nil { - log.Error().Err(err).Msg("unable to connect to asset") - assetErrors = append(assetErrors, &assetWithError{ - asset: resolvedAsset, - err: err, - }) - continue - } - asset = runtime.Provider.Connection.Asset // to ensure we get all the information the connect call gave us - - // for all discovered assets, we apply mondoo-specific labels and annotations that come from the root asset - for _, a := range runtime.Provider.Connection.GetInventory().GetSpec().GetAssets() { - a.AddMondooLabels(asset) - a.AddAnnotations(asset.GetAnnotations()) - } - processedAssets, err := providers.ProcessAssetCandidates(runtime, runtime.Provider.Connection, upstream, "") - if err != nil { - assetErrors = append(assetErrors, &assetWithError{ - asset: resolvedAsset, - err: err, - }) - continue - } - for i := range processedAssets { - assetCandidates = append(assetCandidates, &assetWithRuntime{ - asset: processedAssets[i], - runtime: runtime, - }) - } - // TODO: we want to keep better track of errors, since there may be - // multiple assets coming in. It's annoying to abort the scan if we get one - // error at this stage. - - // we grab the asset from the connection, because it contains all the - // detected metadata (and IDs) - // assets = append(assets, runtime.Provider.Connection.Asset) + return nil, err } - // For each asset candidate, we initialize a new runtime and connect to it. + // For each discovered asset, we initialize a new runtime and connect to it. // Within this process, we set up a catch-all deferred function, that shuts - // down all runtimes, in case we exit early. 
The list of assets only gets - // set in the block below this deferred function. + // down all runtimes, in case we exit early. defer func() { - for i := range assets { - asset := assets[i] + for _, asset := range discoveredAssets.Assets { // we can call close multiple times and it will only execute once - if asset.runtime != nil { - asset.runtime.Close() + if asset.Runtime != nil { + asset.Runtime.Close() } } }() - for i := range assetCandidates { - candidate := assetCandidates[i] - - var runtime *providers.Runtime - if candidate.asset.Connections[0].Type == "k8s" { - runtime, err = providers.Coordinator.RuntimeFor(candidate.asset, providers.DefaultRuntime()) - if err != nil { - return nil, false, err - } - } else { - runtime, err = providers.Coordinator.EphemeralRuntimeFor(candidate.asset) - if err != nil { - return nil, false, err - } - } - err = runtime.SetRecording(candidate.runtime.Recording()) - if err != nil { - log.Error().Err(err).Msg("unable to set recording for asset (pre-connect)") - continue - } - - err = runtime.Connect(&plugin.ConnectReq{ - Features: config.Features, - Asset: candidate.asset, - Upstream: upstream, - }) - candidate.asset = runtime.Provider.Connection.Asset // to ensure we get all the information the connect call gave us - if err != nil { - log.Error().Err(err).Str("asset", candidate.asset.Name).Msg("unable to connect to asset") - continue - } - - if candidate.asset.GetPlatform() == nil { - log.Error().Msgf("unable to detect platform for asset " + candidate.asset.Name) - continue - } - assets = append(assets, &assetWithRuntime{ - asset: candidate.asset, - runtime: runtime, - }) - } - - // if there is exactly one asset, assure that the --asset-name is used - // TODO: make it so that the --asset-name is set for the root asset only even if multiple assets are there - // This is a temporary fix that only works if there is only one asset - if len(assets) == 1 && assetList[0].Name != "" && assetList[0].Name != assets[0].asset.Name { - log.Debug().Str("asset", assets[0].asset.Name).Msg("Overriding asset name with --asset-name flag") - assets[0].asset.Name = assetList[0].Name - } - - justAssets := []*inventory.Asset{} - for _, asset := range assets { - asset.asset.KindString = asset.asset.GetPlatform().Kind - justAssets = append(justAssets, asset.asset) - } - for _, asset := range assetErrors { - justAssets = append(justAssets, asset.asset) - } - // plan scan jobs - reporter := NewAggregateReporter(justAssets) + reporter := NewAggregateReporter() // if we had asset errors we want to place them into the reporter - for i := range assetErrors { - reporter.AddScanError(assetErrors[i].asset, assetErrors[i].err) + for i := range discoveredAssets.Errors { + reporter.AddScanError(discoveredAssets.Errors[i].Asset, discoveredAssets.Errors[i].Err) } - if len(assets) == 0 { - return reporter.Reports(), false, nil + if len(discoveredAssets.Assets) == 0 { + return reporter.Reports(), nil } - // sync assets - if upstream != nil && upstream.ApiEndpoint != "" && !upstream.Incognito { - log.Info().Msg("synchronize assets") - client, err := upstream.InitClient() - if err != nil { - return nil, false, err - } - - services, err := explorer.NewRemoteServices(client.ApiEndpoint, client.Plugins, client.HttpClient) - if err != nil { - return nil, false, err + multiprogress, err := CreateProgressBar(discoveredAssets, s.disableProgressBar) + if err != nil { + return nil, err + } + // start the progress bar + scanGroups := sync.WaitGroup{} + scanGroups.Add(1) + go func() { + defer 
scanGroups.Done() + if err := multiprogress.Open(); err != nil { + log.Error().Err(err).Msg("failed to open progress bar") } + }() - inventory.DeprecatedV8CompatAssets(justAssets) - resp, err := services.SynchronizeAssets(ctx, &explorer.SynchronizeAssetsReq{ - SpaceMrn: client.SpaceMrn, - List: justAssets, - }) - if err != nil { - return nil, false, err - } - log.Debug().Int("assets", len(resp.Details)).Msg("got assets details") - platformAssetMapping := make(map[string]*explorer.SynchronizeAssetsRespAssetDetail) - for i := range resp.Details { - log.Debug().Str("platform-mrn", resp.Details[i].PlatformMrn).Str("asset", resp.Details[i].AssetMrn).Msg("asset mapping") - platformAssetMapping[resp.Details[i].PlatformMrn] = resp.Details[i] - } + assetBatches := slicesx.Batch(discoveredAssets.Assets, 100) + for i := range assetBatches { + batch := assetBatches[i] - // attach the asset details to the assets list - for i := range assets { - log.Debug().Str("asset", assets[i].asset.Name).Strs("platform-ids", assets[i].asset.PlatformIds).Msg("update asset") - platformMrn := assets[i].asset.PlatformIds[0] - assets[i].asset.Mrn = platformAssetMapping[platformMrn].AssetMrn - assets[i].asset.Url = platformAssetMapping[platformMrn].Url - } - } else { - // ensure we have non-empty asset MRNs - for i := range assets { - cur := assets[i] - if cur.asset.Mrn == "" { - randID := "//" + explorer.SERVICE_NAME + "/" + explorer.MRN_RESOURCE_ASSET + "/" + ksuid.New().String() - x, err := mrn.NewMRN(randID) - if err != nil { - return nil, false, multierr.Wrap(err, "failed to generate a random asset MRN") - } - cur.asset.Mrn = x.String() + // sync assets + if upstream != nil && upstream.ApiEndpoint != "" && !upstream.Incognito { + log.Info().Msg("synchronize assets") + client, err := upstream.InitClient() + if err != nil { + return nil, err } - } - } - - // if a bundle was provided check that it matches the filter, bundles can also be downloaded - // later therefore we do not want to stop execution here - if job.Bundle != nil && job.Bundle.FilterQueryPacks(job.QueryPackFilters) { - return nil, false, errors.New("all available packs filtered out. 
nothing to do") - } - progressBarElements := map[string]string{} - orderedKeys := []string{} - for i := range assets { - // this shouldn't happen, but might - // it normally indicates a bug in the provider - if presentAsset, present := progressBarElements[assets[i].asset.PlatformIds[0]]; present { - return reporter.Reports(), false, fmt.Errorf("asset %s and %s have the same platform id %s", presentAsset, assets[i].asset.Name, assets[i].asset.PlatformIds[0]) - } - progressBarElements[assets[i].asset.PlatformIds[0]] = assets[i].asset.Name - orderedKeys = append(orderedKeys, assets[i].asset.PlatformIds[0]) - } - var multiprogress progress.MultiProgress - if isatty.IsTerminal(os.Stdout.Fd()) && !s.disableProgressBar && !strings.EqualFold(logger.GetLevel(), "debug") && !strings.EqualFold(logger.GetLevel(), "trace") { - var err error - multiprogress, err = progress.NewMultiProgressBars(progressBarElements, orderedKeys) - if err != nil { - return nil, false, multierr.Wrap(err, "failed to create progress bars") - } - } else { - // TODO: adjust naming - multiprogress = progress.NoopMultiProgressBars{} - } + services, err := explorer.NewRemoteServices(client.ApiEndpoint, client.Plugins, client.HttpClient) + if err != nil { + return nil, err + } - scanGroup := sync.WaitGroup{} - scanGroup.Add(1) - finished := false - go func() { - defer scanGroup.Done() - for i := range assets { - asset := assets[i].asset - runtime := assets[i].runtime - - // Make sure the context has not been canceled in the meantime. Note that this approach works only for single threaded execution. If we have more than 1 thread calling this function, - // we need to solve this at a different level. - select { - case <-ctx.Done(): - log.Warn().Msg("request context has been canceled") - // When we scan concurrently, we need to call Errored(asset.Mrn) status for this asset - multiprogress.Close() - return - default: + assetsToSync := make([]*inventory.Asset, 0, len(batch)) + for i := range batch { + assetsToSync = append(assetsToSync, batch[i].Asset) } - p := &progress.MultiProgressAdapter{Key: asset.PlatformIds[0], Multi: multiprogress} - s.RunAssetJob(&AssetJob{ - DoRecord: job.DoRecord, - UpstreamConfig: upstream, - Asset: asset, - Bundle: job.Bundle, - Props: job.Props, - QueryPackFilters: preprocessQueryPackFilters(job.QueryPackFilters), - Ctx: ctx, - Reporter: reporter, - ProgressReporter: p, - runtime: runtime, + resp, err := services.SynchronizeAssets(ctx, &explorer.SynchronizeAssetsReq{ + SpaceMrn: client.SpaceMrn, + List: assetsToSync, }) + if err != nil { + return nil, err + } + log.Debug().Int("assets", len(resp.Details)).Msg("got assets details") + platformAssetMapping := make(map[string]*explorer.SynchronizeAssetsRespAssetDetail) + for i := range resp.Details { + log.Debug().Str("platform-mrn", resp.Details[i].PlatformMrn).Str("asset", resp.Details[i].AssetMrn).Msg("asset mapping") + platformAssetMapping[resp.Details[i].PlatformMrn] = resp.Details[i] + } - // runtimes are single-use only. Close them once they are done. 
- runtime.Close() + // attach the asset details to the assets list + for i := range batch { + asset := batch[i].Asset + log.Debug().Str("asset", asset.Name).Strs("platform-ids", asset.PlatformIds).Msg("update asset") + platformMrn := asset.PlatformIds[0] + asset.Mrn = platformAssetMapping[platformMrn].AssetMrn + asset.Url = platformAssetMapping[platformMrn].Url + } + } else { + // ensure we have non-empty asset MRNs + for i := range batch { + asset := batch[i].Asset + if asset.Mrn == "" { + randID := "//" + explorer.SERVICE_NAME + "/" + explorer.MRN_RESOURCE_ASSET + "/" + ksuid.New().String() + x, err := mrn.NewMRN(randID) + if err != nil { + return nil, multierr.Wrap(err, "failed to generate a random asset MRN") + } + asset.Mrn = x.String() + } + } } - finished = true - }() - scanGroup.Add(1) - go func() { - defer scanGroup.Done() - multiprogress.Open() - }() + // if a bundle was provided check that it matches the filter, bundles can also be downloaded + // later therefore we do not want to stop execution here + if job.Bundle != nil && job.Bundle.FilterQueryPacks(job.QueryPackFilters) { + return nil, errors.New("all available packs filtered out. nothing to do") + } + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + for i := range batch { + asset := batch[i].Asset + runtime := batch[i].Runtime + + // Make sure the context has not been canceled in the meantime. Note that this approach works only for single threaded execution. If we have more than 1 thread calling this function, + // we need to solve this at a different level. + select { + case <-ctx.Done(): + log.Warn().Msg("request context has been canceled") + // When we scan concurrently, we need to call Errored(asset.Mrn) status for this asset + multiprogress.Close() + return + default: + } - scanGroup.Wait() - return reporter.Reports(), finished, nil + p := &progress.MultiProgressAdapter{Key: asset.PlatformIds[0], Multi: multiprogress} + s.RunAssetJob(&AssetJob{ + DoRecord: job.DoRecord, + UpstreamConfig: upstream, + Asset: asset, + Bundle: job.Bundle, + Props: job.Props, + QueryPackFilters: preprocessQueryPackFilters(job.QueryPackFilters), + Ctx: ctx, + Reporter: reporter, + ProgressReporter: p, + runtime: runtime, + }) + + // runtimes are single-use only. Close them once they are done. 
+ runtime.Close() + } + }() + wg.Wait() + } + scanGroups.Wait() + return reporter.Reports(), nil } func (s *LocalScanner) RunAssetJob(job *AssetJob) { diff --git a/explorer/scan/reporter.go b/explorer/scan/reporter.go index f53df398e1..ff90e38222 100644 --- a/explorer/scan/reporter.go +++ b/explorer/scan/reporter.go @@ -29,18 +29,9 @@ type AggregateReporter struct { resolved map[string]*explorer.ResolvedPack } -func NewAggregateReporter(assetList []*inventory.Asset) *AggregateReporter { - assets := make(map[string]*explorer.Asset, len(assetList)) - for i := range assetList { - cur := assetList[i] - assets[cur.Mrn] = &explorer.Asset{ - Mrn: cur.Mrn, - Name: cur.Name, - } - } - +func NewAggregateReporter() *AggregateReporter { return &AggregateReporter{ - assets: assets, + assets: map[string]*explorer.Asset{}, assetReports: map[string]*explorer.Report{}, assetErrors: map[string]error{}, resolved: map[string]*explorer.ResolvedPack{}, @@ -48,12 +39,14 @@ func NewAggregateReporter(assetList []*inventory.Asset) *AggregateReporter { } func (r *AggregateReporter) AddReport(asset *inventory.Asset, results *AssetReport) { + r.assets[asset.Mrn] = &explorer.Asset{Name: asset.Name, Mrn: asset.Mrn} r.assetReports[asset.Mrn] = results.Report r.resolved[asset.Mrn] = results.Resolved r.bundle = results.Bundle } func (r *AggregateReporter) AddScanError(asset *inventory.Asset, err error) { + r.assets[asset.Mrn] = &explorer.Asset{Name: asset.Name, Mrn: asset.Mrn} r.assetErrors[asset.Mrn] = err } diff --git a/explorer/scan/testdata/2pods.yaml b/explorer/scan/testdata/2pods.yaml new file mode 100644 index 0000000000..bcb4edfe1e --- /dev/null +++ b/explorer/scan/testdata/2pods.yaml @@ -0,0 +1,32 @@ +--- +apiVersion: v1 +kind: Pod +metadata: + labels: + admission-result: pass + name: passing-pod-yaml + namespace: default +spec: + automountServiceAccountToken: false + containers: + - image: ubuntu:20.04 + imagePullPolicy: Always + command: ["/bin/sh", "-c"] + args: ["sleep 6000"] + name: ubuntu +--- +apiVersion: v1 +kind: Pod +metadata: + labels: + admission-result: pass + name: passing-pod-yaml-2 + namespace: default +spec: + automountServiceAccountToken: false + containers: + - image: ubuntu:20.04 + imagePullPolicy: Always + command: ["/bin/sh", "-c"] + args: ["sleep 6000"] + name: ubuntu \ No newline at end of file diff --git a/explorer/scan/testdata/3pods_with_duplicate.yaml b/explorer/scan/testdata/3pods_with_duplicate.yaml new file mode 100644 index 0000000000..b469d08dc0 --- /dev/null +++ b/explorer/scan/testdata/3pods_with_duplicate.yaml @@ -0,0 +1,48 @@ +--- +apiVersion: v1 +kind: Pod +metadata: + labels: + admission-result: pass + name: passing-pod-yaml + namespace: default +spec: + automountServiceAccountToken: false + containers: + - image: ubuntu:20.04 + imagePullPolicy: Always + command: ["/bin/sh", "-c"] + args: ["sleep 6000"] + name: ubuntu +--- +apiVersion: v1 +kind: Pod +metadata: + labels: + admission-result: pass + name: passing-pod-yaml-2 + namespace: default +spec: + automountServiceAccountToken: false + containers: + - image: ubuntu:20.04 + imagePullPolicy: Always + command: ["/bin/sh", "-c"] + args: ["sleep 6000"] + name: ubuntu +--- +apiVersion: v1 +kind: Pod +metadata: + labels: + admission-result: pass + name: passing-pod-yaml-2 + namespace: default +spec: + automountServiceAccountToken: false + containers: + - image: ubuntu:20.04 + imagePullPolicy: Always + command: ["/bin/sh", "-c"] + args: ["sleep 6000"] + name: ubuntu \ No newline at end of file diff --git a/providers/assets.go 
b/providers/assets.go deleted file mode 100644 index e1198e1efc..0000000000 --- a/providers/assets.go +++ /dev/null @@ -1,110 +0,0 @@ -// Copyright (c) Mondoo, Inc. -// SPDX-License-Identifier: BUSL-1.1 - -package providers - -import ( - "github.com/cockroachdb/errors" - "github.com/rs/zerolog/log" - "go.mondoo.com/cnquery/v10/cli/config" - "go.mondoo.com/cnquery/v10/logger" - "go.mondoo.com/cnquery/v10/providers-sdk/v1/inventory" - pp "go.mondoo.com/cnquery/v10/providers-sdk/v1/plugin" - "go.mondoo.com/cnquery/v10/providers-sdk/v1/upstream" -) - -func ProcessAssetCandidates(runtime *Runtime, connectRes *pp.ConnectRes, upstreamConfig *upstream.UpstreamConfig, platformID string) ([]*inventory.Asset, error) { - var assetCandidates []*inventory.Asset - if connectRes.Inventory == nil || connectRes.Inventory.Spec == nil { - return []*inventory.Asset{connectRes.Asset}, nil - } else { - logger.DebugDumpJSON("inventory-resolved", connectRes.Inventory) - assetCandidates = connectRes.Inventory.Spec.Assets - } - log.Debug().Msgf("resolved %d assets", len(assetCandidates)) - - if err := detectAssets(runtime, assetCandidates, upstreamConfig); err != nil { - return nil, err - } - - if platformID != "" { - res, err := filterAssetByPlatformID(assetCandidates, platformID) - if err != nil { - return nil, err - } - return []*inventory.Asset{res}, nil - } - - return filterUniqueAssets(assetCandidates), nil -} - -// detectAssets connects to all assets that do not have a platform ID yet -func detectAssets(runtime *Runtime, assetCandidates []*inventory.Asset, upstreamConfig *upstream.UpstreamConfig) error { - for i := range assetCandidates { - asset := assetCandidates[i] - // If the assets have platform IDs, then we have already connected to them via the - // current provider. 
- if len(asset.PlatformIds) > 0 { - continue - } - - // Make sure the provider for the asset is present - if err := runtime.DetectProvider(asset); err != nil { - return err - } - - err := runtime.Connect(&pp.ConnectReq{ - Features: config.Features, - Asset: asset, - Upstream: upstreamConfig, - }) - if err != nil { - continue - } - // Use the updated asset - assetCandidates[i] = runtime.Provider.Connection.Asset - } - return nil -} - -func filterAssetByPlatformID(assetList []*inventory.Asset, selectionID string) (*inventory.Asset, error) { - var foundAsset *inventory.Asset - for i := range assetList { - assetObj := assetList[i] - for j := range assetObj.PlatformIds { - if assetObj.PlatformIds[j] == selectionID { - return assetObj, nil - } - } - } - - if foundAsset == nil { - return nil, errors.New("could not find an asset with the provided identifier: " + selectionID) - } - return foundAsset, nil -} - -// filterUniqueAssets filters assets with duplicate platform IDs -func filterUniqueAssets(assetCandidates []*inventory.Asset) []*inventory.Asset { - uniqueAssets := []*inventory.Asset{} - platformIds := map[string]struct{}{} - for _, asset := range assetCandidates { - found := false - for _, platformId := range asset.PlatformIds { - if _, ok := platformIds[platformId]; ok { - found = true - log.Debug().Msgf("skipping asset %s with duplicate platform ID %s", asset.Name, platformId) - break - } - } - if found { - continue - } - - uniqueAssets = append(uniqueAssets, asset) - for _, platformId := range asset.PlatformIds { - platformIds[platformId] = struct{}{} - } - } - return uniqueAssets -} diff --git a/utils/slicesx/batch.go b/utils/slicesx/batch.go new file mode 100644 index 0000000000..c187008931 --- /dev/null +++ b/utils/slicesx/batch.go @@ -0,0 +1,16 @@ +// Copyright (c) Mondoo, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package slicesx + +func Batch[T any](list []T, batchSize int) [][]T { + var res [][]T + for i := 0; i < len(list); i += batchSize { + end := i + batchSize + if end > len(list) { + end = len(list) + } + res = append(res, list[i:end]) + } + return res +} diff --git a/utils/slicesx/batch_test.go b/utils/slicesx/batch_test.go new file mode 100644 index 0000000000..73c8728c8a --- /dev/null +++ b/utils/slicesx/batch_test.go @@ -0,0 +1,43 @@ +// Copyright (c) Mondoo, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package slicesx + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBatch(t *testing.T) { + slice := []string{} + for i := 0; i < 100; i++ { + slice = append(slice, fmt.Sprintf("item-%d", i)) + } + + batches := Batch(slice, 10) + assert.Len(t, batches, 10) + + flattenedBatches := []string{} + for _, batch := range batches { + flattenedBatches = append(flattenedBatches, batch...) + } + assert.Equal(t, slice, flattenedBatches) +} + +func TestBatch_Uneven(t *testing.T) { + slice := []string{} + for i := 0; i < 101; i++ { + slice = append(slice, fmt.Sprintf("item-%d", i)) + } + + batches := Batch(slice, 10) + assert.Len(t, batches, 11) + + flattenedBatches := []string{} + for _, batch := range batches { + flattenedBatches = append(flattenedBatches, batch...) + } + assert.Equal(t, slice, flattenedBatches) +}
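
For reference, a minimal usage sketch of the new slicesx.Batch helper added above. The import path is taken from this patch; the sample data and the output comment are illustrative only.

package main

import (
	"fmt"

	"go.mondoo.com/cnquery/v10/utils/slicesx"
)

func main() {
	assets := []string{"asset-1", "asset-2", "asset-3", "asset-4", "asset-5"}

	// Split the slice into batches of at most 2 elements; the final batch
	// holds whatever is left over (here: a single element).
	for i, batch := range slicesx.Batch(assets, 2) {
		fmt.Printf("batch %d: %v\n", i, batch)
	}
	// Output:
	// batch 0: [asset-1 asset-2]
	// batch 1: [asset-3 asset-4]
	// batch 2: [asset-5]
}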
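
A sketch of how callers are expected to use the new discovery API, modeled on the apps/cnquery/cmd/shell.go hunk above. The function pickAsset and its package are hypothetical; DiscoverAssets, GetAssetsByPlatformID, AssetWithRuntime, and providers.NullRecording come from this patch, and a nil upstream config is assumed for a local, incognito run.

package example

import (
	"context"
	"errors"

	"go.mondoo.com/cnquery/v10/explorer/scan"
	"go.mondoo.com/cnquery/v10/providers"
	"go.mondoo.com/cnquery/v10/providers-sdk/v1/inventory"
)

// pickAsset mirrors the shell.go change above: discover every unique asset in
// the inventory, then narrow the result down by platform ID (an empty ID
// returns all assets). Error handling is trimmed to the essentials.
func pickAsset(ctx context.Context, inv *inventory.Inventory, platformID string) (*inventory.Asset, error) {
	discovered, err := scan.DiscoverAssets(ctx, inv, nil, providers.NullRecording{})
	if err != nil {
		return nil, err
	}
	// Every discovered asset carries an open runtime; close them when done.
	// Per the comment in distributeJob, Close can safely be called more than once.
	defer func() {
		for _, a := range discovered.Assets {
			a.Runtime.Close()
		}
	}()

	candidates := discovered.GetAssetsByPlatformID(platformID)
	if len(candidates) == 0 {
		return nil, errors.New("could not find an asset that we can connect to")
	}
	return candidates[0], nil
}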
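
The duplicate filtering in DiscoveredAssets.Add hinges on platform IDs. The standalone sketch below re-implements that idea in simplified form (IDs are recorded only for accepted assets) so the behavior is easy to try outside the scan package; every name in it is illustrative.

package main

import "fmt"

type asset struct {
	name        string
	platformIDs []string
}

// dedupe keeps the first asset for any given platform ID and drops later
// assets that share even one already-seen ID.
func dedupe(assets []asset) []asset {
	seen := map[string]struct{}{}
	var unique []asset
	for _, a := range assets {
		dup := false
		for _, id := range a.platformIDs {
			if _, ok := seen[id]; ok {
				dup = true
				break
			}
		}
		if dup {
			continue
		}
		for _, id := range a.platformIDs {
			seen[id] = struct{}{}
		}
		unique = append(unique, a)
	}
	return unique
}

func main() {
	assets := []asset{
		{name: "pod-a", platformIDs: []string{"platform1"}},
		{name: "pod-b", platformIDs: []string{"platform2", "platform3"}},
		{name: "pod-b-again", platformIDs: []string{"platform3", "platform4"}}, // shares platform3 -> dropped
	}
	fmt.Println(dedupe(assets)) // [{pod-a [platform1]} {pod-b [platform2 platform3]}]
}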