Skip to content

Commit

Permalink
🔧 expand EnsureProvider to cover IDs (#2306)
Browse files Browse the repository at this point in the history
Refactor the whole EnsureProvider flow so users can specify providers
via ID, connector name and connector type. All of them are useful:

1. Install provider from Defaults => just specify the ID
2. Install provider from CLI args (like "local") => just specify name
3. Install provider from connection type => doh

Signed-off-by: Dominik Richter <[email protected]>
  • Loading branch information
arlimus authored Oct 20, 2023
1 parent bce2b46 commit 216f9f8
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 46 deletions.
2 changes: 1 addition & 1 deletion cli/providers/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func AttachCLIs(rootCmd *cobra.Command, commands ...*Command) error {

connectorName, autoUpdate := detectConnectorName(os.Args, rootCmd, commands, existing)
if connectorName != "" {
if _, err := providers.EnsureProvider(connectorName, "", autoUpdate, existing); err != nil {
if _, err := providers.EnsureProvider(providers.ProviderLookup{ConnName: connectorName}, autoUpdate, existing); err != nil {
return err
}
}
Expand Down
2 changes: 1 addition & 1 deletion cli/sysinfo/sysinfo.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func GatherSystemInfo(opts ...SystemInfoOption) (*SystemInfo, error) {
cfg.runtime = providers.Coordinator.NewRuntime()

// init runtime
if _, err := providers.EnsureProvider("local", "", true, nil); err != nil {
if _, err := providers.EnsureProvider(providers.ProviderLookup{ConnName: "local"}, true, nil); err != nil {
return nil, err
}
if err := cfg.runtime.UseProvider(providers.DefaultOsID); err != nil {
Expand Down
115 changes: 72 additions & 43 deletions providers/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,71 @@ func init() {
LastProviderInstall = time.Now().Unix()
}

type ProviderLookup struct {
ID string
ConnName string
ConnType string
}

func (s ProviderLookup) String() string {
res := []string{}
if s.ID != "" {
res = append(res, "id="+s.ID)
}
if s.ConnName != "" {
res = append(res, "name="+s.ConnName)
}
if s.ConnType != "" {
res = append(res, "name="+s.ConnType)
}
return strings.Join(res, " ")
}

type Providers map[string]*Provider

// Lookup a provider in this list. If you search via ProviderID we will
// try to find the exact provider. Otherwise we will try to find a matching
// connector type first and name second.
func (p Providers) Lookup(search ProviderLookup) *Provider {
if search.ID != "" {
return p[search.ID]
}

if search.ConnType != "" {
for _, provider := range p {
if slices.Contains(provider.ConnectionTypes, search.ConnType) {
return provider
}
for i := range provider.Connectors {
if slices.Contains(provider.Connectors[i].Aliases, search.ConnType) {
return provider
}
}
}
}

if search.ConnName != "" {
for _, provider := range p {
for i := range provider.Connectors {
if provider.Connectors[i].Name == search.ConnName {
return provider
}
if slices.Contains(provider.Connectors[i].Aliases, search.ConnName) {
return provider
}
}
}
}

return nil
}

func (p Providers) Add(nu *Provider) {
if nu != nil {
p[nu.ID] = nu
}
}

type Provider struct {
*plugin.Provider
Schema *resources.Schema
Expand Down Expand Up @@ -169,15 +232,16 @@ func ListAll() ([]*Provider, error) {

// EnsureProvider makes sure that a given provider exists and returns it.
// You can supply providers either via:
// 1. connectorName, which is what you see in the CLI e.g. "local", "ssh", ...
// 2. connectorType, which is how assets define the connector type when
// 1. providerID, which universally identifies it, e.g. "go.mondoo.com/cnquery/v9/providers/os"
// 2. connectorName, which is what you see in the CLI e.g. "local", "ssh", ...
// 3. connectorType, which is how assets define the connector type when
// they are moved between discovery and execution, e.g. "registry-image".
//
// If you disable autoUpdate, it will neither update NOR install missing providers.
//
// If you don't supply existing providers, it will look for alist of all
// active providers first.
func EnsureProvider(connectorName string, connectorType string, autoUpdate bool, existing Providers) (*Provider, error) {
func EnsureProvider(search ProviderLookup, autoUpdate bool, existing Providers) (*Provider, error) {
if existing == nil {
var err error
existing, err = ListActive()
Expand All @@ -186,31 +250,32 @@ func EnsureProvider(connectorName string, connectorType string, autoUpdate bool,
}
}

provider := existing.ForConnection(connectorName, connectorType)
provider := existing.Lookup(search)
if provider != nil {
return provider, nil
}

if connectorName == "mock" || connectorType == "mock" {
if search.ID == mockProvider.ID || search.ConnName == "mock" || search.ConnType == "mock" {
existing.Add(&mockProvider)
return &mockProvider, nil
}

upstream := DefaultProviders.ForConnection(connectorName, connectorType)
upstream := DefaultProviders.Lookup(search)
if upstream == nil {
// we can't find any provider for this connector in our default set
// FIXME: This causes a panic in the CLI, we should handle this better
return nil, nil
}

if !autoUpdate {
return nil, errors.New("cannot find installed provider for connection " + connectorName)
return nil, errors.New("cannot find installed provider for " + search.String())
}

nu, err := Install(upstream.Name, "")
if err != nil {
return nil, err
}

existing.Add(nu)
PrintInstallResults([]*Provider{nu})
return nu, nil
Expand Down Expand Up @@ -626,42 +691,6 @@ func (p *Provider) binPath() string {
return filepath.Join(p.Path, name)
}

func (p Providers) ForConnection(name string, typ string) *Provider {
if name != "" {
for _, provider := range p {
for i := range provider.Connectors {
if provider.Connectors[i].Name == name {
return provider
}
if slices.Contains(provider.Connectors[i].Aliases, name) {
return provider
}
}
}
}

if typ != "" {
for _, provider := range p {
if slices.Contains(provider.ConnectionTypes, typ) {
return provider
}
for i := range provider.Connectors {
if slices.Contains(provider.Connectors[i].Aliases, typ) {
return provider
}
}
}
}

return nil
}

func (p Providers) Add(nu *Provider) {
if nu != nil {
p[nu.ID] = nu
}
}

func MustLoadSchema(name string, data []byte) *resources.Schema {
var res resources.Schema
if err := json.Unmarshal(data, &res); err != nil {
Expand Down
2 changes: 1 addition & 1 deletion providers/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ func (r *Runtime) DetectProvider(asset *inventory.Asset) error {
conn.Type = inventory.ConnBackendToType(conn.Backend)
}

provider, err := EnsureProvider("", conn.Type, true, r.coordinator.Providers)
provider, err := EnsureProvider(ProviderLookup{ConnType: conn.Type}, true, r.coordinator.Providers)
if err != nil {
errs.Add(err)
continue
Expand Down

0 comments on commit 216f9f8

Please sign in to comment.