Skip to content

Commit

Permalink
🐛 Load provider flags from environment variables (#4847)
Browse files Browse the repository at this point in the history
  • Loading branch information
afiune authored Nov 13, 2024
1 parent 7748bcf commit 0b4c641
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 25 deletions.
31 changes: 9 additions & 22 deletions cli/providers/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ package providers

import (
"encoding/json"
"go.mondoo.com/cnquery/v11/utils/piped"
"go.mondoo.com/ranger-rpc/status"
"os"
"strings"

Expand All @@ -21,6 +19,8 @@ import (
"go.mondoo.com/cnquery/v11/providers-sdk/v1/plugin"
"go.mondoo.com/cnquery/v11/providers-sdk/v1/recording"
"go.mondoo.com/cnquery/v11/types"
"go.mondoo.com/cnquery/v11/utils/piped"
"go.mondoo.com/ranger-rpc/status"
)

type Command struct {
Expand Down Expand Up @@ -318,35 +318,22 @@ func attachFlags(flagset *pflag.FlagSet, flags []plugin.Flag) {
}
}

func getFlagValue(flag plugin.Flag, cmd *cobra.Command) *llx.Primitive {
func getFlagValue(flag plugin.Flag) *llx.Primitive {
switch flag.Type {
case plugin.FlagType_Bool:
v, err := cmd.Flags().GetBool(flag.Long)
if err == nil {
return llx.BoolPrimitive(v)
}
log.Warn().Err(err).Msg("failed to get flag " + flag.Long)
return llx.BoolPrimitive(viper.GetBool(flag.Long))
case plugin.FlagType_Int:
if v, err := cmd.Flags().GetInt(flag.Long); err == nil {
return llx.IntPrimitive(int64(v))
}
return llx.IntPrimitive(viper.GetInt64(flag.Long))
case plugin.FlagType_String:
if v, err := cmd.Flags().GetString(flag.Long); err == nil {
return llx.StringPrimitive(v)
}
return llx.StringPrimitive(viper.GetString(flag.Long))
case plugin.FlagType_List:
if v, err := cmd.Flags().GetStringSlice(flag.Long); err == nil {
return llx.ArrayPrimitiveT(v, llx.StringPrimitive, types.String)
}
return llx.ArrayPrimitiveT(viper.GetStringSlice(flag.Long), llx.StringPrimitive, types.String)
case plugin.FlagType_KeyValue:
if v, err := cmd.Flags().GetStringToString(flag.Long); err == nil {
return llx.MapPrimitiveT(v, llx.StringPrimitive, types.String)
}
return llx.MapPrimitiveT(viper.GetStringMapString(flag.Long), llx.StringPrimitive, types.String)
default:
log.Warn().Msg("unknown flag type for " + flag.Long)
return nil
}
return nil
}

func setConnector(provider *plugin.Provider, connector *plugin.Connector, run func(*cobra.Command, *providers.Runtime, *plugin.ParseCLIRes), cmd *cobra.Command) {
Expand Down Expand Up @@ -421,7 +408,7 @@ func setConnector(provider *plugin.Provider, connector *plugin.Connector, run fu
continue
}

if v := getFlagValue(flag, cmd); v != nil {
if v := getFlagValue(flag); v != nil {
flagVals[flag.Long] = v
}
}
Expand Down
47 changes: 44 additions & 3 deletions test/providers/os_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,16 @@
package providers

import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.mondoo.com/cnquery/v11/test"
"log"
"os"
"os/exec"
"path/filepath"
"sync"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.mondoo.com/cnquery/v11/test"
)

var once sync.Once
Expand Down Expand Up @@ -185,3 +186,43 @@ func TestOsProviderSharedTests(t *testing.T) {
}
}
}

func TestProvidersEnvVarsLoading(t *testing.T) {
t.Run("command WITHOUT path should not find any package", func(t *testing.T) {
r := test.NewCliTestRunner("./cnquery", "run", "fs", "-c", mqlPackagesQuery, "-j")
err := r.Run()
require.NoError(t, err)
assert.Equal(t, 0, r.ExitCode())
assert.NotNil(t, r.Stdout())
assert.NotNil(t, r.Stderr())

var c mqlPackages
err = r.Json(&c)
assert.NoError(t, err)

// No packages
assert.Empty(t, c)
})
t.Run("command WITH path should find packages", func(t *testing.T) {
os.Setenv("MONDOO_PATH", "./testdata/fs")
defer os.Unsetenv("MONDOO_PATH")
// Note we are not passing the flag "--path ./testdata/fs"
r := test.NewCliTestRunner("./cnquery", "run", "fs", "-c", mqlPackagesQuery, "-j")
err := r.Run()
require.NoError(t, err)
assert.Equal(t, 0, r.ExitCode())
assert.NotNil(t, r.Stdout())
assert.NotNil(t, r.Stderr())

var c mqlPackages
err = r.Json(&c)
assert.NoError(t, err)

// Should have packages
if assert.NotEmpty(t, c) {
x := c[0]
assert.NotNil(t, x.Packages)
assert.True(t, len(x.Packages) > 0)
}
})
}

0 comments on commit 0b4c641

Please sign in to comment.