Skip to content

Commit

Permalink
🐛 fix shell command (#3342)
Browse files Browse the repository at this point in the history
  • Loading branch information
imilchev authored Feb 17, 2024
1 parent e2a9235 commit 6edda62
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 27 deletions.
26 changes: 13 additions & 13 deletions apps/cnquery/cmd/shell.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,32 +108,32 @@ func StartShell(runtime *providers.Runtime, conf *ShellConfig) error {

connectAsset := filteredAssets[0]
if len(filteredAssets) > 1 {
invAssets := make([]*inventory.Asset, 0, len(filteredAssets))
for _, a := range filteredAssets {
invAssets = append(invAssets, a.Asset)
}

isTTY := isatty.IsTerminal(os.Stdout.Fd())
if isTTY {
connectAsset = components.AssetSelect(filteredAssets)
selectedAsset := components.AssetSelect(invAssets)
connectAsset = filteredAssets[selectedAsset]
} else {
fmt.Println(components.AssetList(theme.OperatingSystemTheme, filteredAssets))
fmt.Println(components.AssetList(theme.OperatingSystemTheme, invAssets))
log.Fatal().Msg("cannot connect to more than one asset, use --platform-id to select a specific asset")
}
}

if connectAsset == nil {
log.Fatal().Msg("no asset selected")
log.Error().Msg("no asset selected")
os.Exit(1)
}

err = runtime.Connect(&plugin.ConnectReq{
Features: conf.Features,
Asset: connectAsset,
Upstream: conf.UpstreamConfig,
})
if err != nil {
log.Fatal().Err(err).Msg("failed to connect to asset")
}
log.Info().Msgf("connected to %s", runtime.Provider.Connection.Asset.Platform.Title)
log.Info().Msgf("connected to %s", connectAsset.Runtime.Provider.Connection.Asset.Platform.Title)

// when we close the shell, we need to close the backend and store the recording
onCloseHandler := func() {
runtime.Close()
connectAsset.Runtime.Close()
providers.Coordinator.Shutdown()
}

Expand All @@ -142,7 +142,7 @@ func StartShell(runtime *providers.Runtime, conf *ShellConfig) error {
shellOptions = append(shellOptions, shell.WithFeatures(conf.Features))
shellOptions = append(shellOptions, shell.WithUpstreamConfig(conf.UpstreamConfig))

sh, err := shell.New(runtime, shellOptions...)
sh, err := shell.New(connectAsset.Runtime, shellOptions...)
if err != nil {
log.Error().Err(err).Msg("failed to initialize interactive shell")
}
Expand Down
6 changes: 3 additions & 3 deletions cli/components/assetselect.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
"go.mondoo.com/cnquery/v10/providers-sdk/v1/inventory"
)

func AssetSelect(assetList []*inventory.Asset) *inventory.Asset {
func AssetSelect(assetList []*inventory.Asset) int {
list := make([]string, len(assetList))

// map asset name to list
Expand All @@ -36,9 +36,9 @@ func AssetSelect(assetList []*inventory.Asset) *inventory.Asset {
}

if selection == -1 {
return nil
return -1
}
selected := assetList[selection]
log.Info().Int("selection", selection).Str("asset", selected.Name).Msg("selected asset")
return selected
return selection
}
6 changes: 3 additions & 3 deletions explorer/scan/discovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,12 @@ func (d *DiscoveredAssets) AddError(asset *inventory.Asset, err error) {
d.Errors = append(d.Errors, &AssetWithError{Asset: asset, Err: err})
}

func (d *DiscoveredAssets) GetAssetsByPlatformID(platformID string) []*inventory.Asset {
var assets []*inventory.Asset
func (d *DiscoveredAssets) GetAssetsByPlatformID(platformID string) []*AssetWithRuntime {
var assets []*AssetWithRuntime
for _, a := range d.Assets {
for _, p := range a.Asset.PlatformIds {
if platformID == "" || p == platformID {
assets = append(assets, a.Asset)
assets = append(assets, a)
break
}
}
Expand Down
4 changes: 2 additions & 2 deletions explorer/scan/discovery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func TestDiscoveredAssets_GetAssetsByPlatformID(t *testing.T) {
// Make sure adding duplicates is not possible
assets := d.GetAssetsByPlatformID(allPlatformIds[0])
assert.Len(t, assets, 1)
assert.Equal(t, allPlatformIds[0], assets[0].PlatformIds[0])
assert.Equal(t, allPlatformIds[0], assets[0].Asset.PlatformIds[0])
}

func TestDiscoveredAssets_GetAssetsByPlatformID_Empty(t *testing.T) {
Expand Down Expand Up @@ -110,7 +110,7 @@ func TestDiscoveredAssets_GetAssetsByPlatformID_Empty(t *testing.T) {
assert.Len(t, assets, 10)
platformIds := []string{}
for _, a := range assets {
platformIds = append(platformIds, a.PlatformIds[0])
platformIds = append(platformIds, a.Asset.PlatformIds[0])
}
assert.ElementsMatch(t, allPlatformIds, platformIds)
}
Expand Down
5 changes: 0 additions & 5 deletions providers/aws/resources/discovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,11 +186,6 @@ func Discover(runtime *plugin.Runtime, filters connection.DiscoveryFilters) (*in
Assets: []*inventory.Asset{},
}}

if (conn.Conf == nil || len(conn.Conf.Discover.Targets) == 0) && conn.Asset() != nil {
in.Spec.Assets = append(in.Spec.Assets, conn.Asset())
return in, nil
}

res, err := NewResource(runtime, "aws.account", map[string]*llx.RawData{"id": llx.StringData("aws.account/" + conn.AccountId())})
if err != nil {
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion providers/aws/resources/discovery_conversion.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ func accountAsset(conn *connection.AwsConnection, awsAccount *mqlAwsAccount) *in
PlatformIds: []string{id},
Name: name,
Platform: connection.GetPlatformForObject(""),
Connections: []*inventory.Config{conn.Conf},
Connections: []*inventory.Config{conn.Conf.Clone(inventory.WithoutDiscovery(), inventory.WithParentConnectionId(conn.Conf.Id))},
Options: conn.ConnectionOptions(),
}
}
Expand Down

0 comments on commit 6edda62

Please sign in to comment.