diff --git a/providers/aws/connection/connection.go b/providers/aws/connection/connection.go index a53f4f15e3..88ecafed95 100644 --- a/providers/aws/connection/connection.go +++ b/providers/aws/connection/connection.go @@ -31,6 +31,13 @@ type AwsConnection struct { connectionOptions map[string]string } +func NewMockConnection(id uint32, asset *inventory.Asset, conf *inventory.Config) *AwsConnection { + return &AwsConnection{ + id: id, + asset: asset, + } +} + func NewAwsConnection(id uint32, asset *inventory.Asset, conf *inventory.Config) (*AwsConnection, error) { log.Debug().Msg("new aws connection") // check flags for connection options diff --git a/providers/aws/provider/provider.go b/providers/aws/provider/provider.go index 5c6c2b76da..0c6e7dcb32 100644 --- a/providers/aws/provider/provider.go +++ b/providers/aws/provider/provider.go @@ -108,7 +108,33 @@ func (s *Service) Shutdown(req *plugin.ShutdownReq) (*plugin.ShutdownRes, error) } func (s *Service) MockConnect(req *plugin.ConnectReq, callback plugin.ProviderCallback) (*plugin.ConnectRes, error) { - return nil, errors.New("mock connect not yet implemented") + if req == nil || req.Asset == nil { + return nil, errors.New("no connection data provided") + } + + asset := &inventory.Asset{ + PlatformIds: req.Asset.PlatformIds, + Platform: req.Asset.Platform, + Connections: []*inventory.Config{{ + Type: "mock", + }}, + } + + conn, err := s.connect(&plugin.ConnectReq{ + Features: req.Features, + Upstream: req.Upstream, + Asset: asset, + }, callback) + if err != nil { + return nil, err + } + + return &plugin.ConnectRes{ + Id: uint32(conn.(shared.Connection).ID()), + Name: conn.(shared.Connection).Name(), + Asset: asset, + Inventory: nil, + }, nil } func (s *Service) Connect(req *plugin.ConnectReq, callback plugin.ProviderCallback) (*plugin.ConnectRes, error) { @@ -161,6 +187,10 @@ func (s *Service) connect(req *plugin.ConnectReq, callback plugin.ProviderCallba var err error switch conf.Type { + case "mock": + s.lastConnectionID++ + conn = connection.NewMockConnection(s.lastConnectionID, asset, conf) + case string(awsec2ebsconn.EBSConnectionType): s.lastConnectionID++ conn, err = awsec2ebsconn.NewAwsEbsConnection(s.lastConnectionID, conf, asset)