diff --git a/internal/cast/cast.go b/internal/castai/castai.go similarity index 99% rename from internal/cast/cast.go rename to internal/castai/castai.go index e4819217..def86cc0 100644 --- a/internal/cast/cast.go +++ b/internal/castai/castai.go @@ -1,20 +1,22 @@ //go:generate mockgen -destination ./mock/client.go . Client -package cast +package castai import ( "bytes" - "castai-agent/internal/config" "context" "encoding/json" "fmt" - "github.com/go-resty/resty/v2" - "github.com/sirupsen/logrus" "io" "mime/multipart" "net/http" "net/textproto" "net/url" "time" + + "github.com/go-resty/resty/v2" + "github.com/sirupsen/logrus" + + "castai-agent/internal/config" ) const ( diff --git a/internal/cast/cast_test.go b/internal/castai/castai_test.go similarity index 99% rename from internal/cast/cast_test.go rename to internal/castai/castai_test.go index a6ab365a..3cb0c2ba 100644 --- a/internal/cast/cast_test.go +++ b/internal/castai/castai_test.go @@ -1,4 +1,4 @@ -package cast +package castai import ( "castai-agent/internal/services/collector" diff --git a/internal/cast/mock/client.go b/internal/castai/mock/client.go similarity index 84% rename from internal/cast/mock/client.go rename to internal/castai/mock/client.go index 95ca04c8..776a4d86 100644 --- a/internal/cast/mock/client.go +++ b/internal/castai/mock/client.go @@ -1,11 +1,11 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: castai-agent/internal/cast (interfaces: Client) +// Source: castai-agent/internal/castai (interfaces: Client) -// Package mock_cast is a generated GoMock package. -package mock_cast +// Package mock_castai is a generated GoMock package. +package mock_castai import ( - cast "castai-agent/internal/cast" + castai "castai-agent/internal/castai" context "context" reflect "reflect" @@ -36,10 +36,10 @@ func (m *MockClient) EXPECT() *MockClientMockRecorder { } // RegisterCluster mocks base method. -func (m *MockClient) RegisterCluster(arg0 context.Context, arg1 *cast.RegisterClusterRequest) (*cast.RegisterClusterResponse, error) { +func (m *MockClient) RegisterCluster(arg0 context.Context, arg1 *castai.RegisterClusterRequest) (*castai.RegisterClusterResponse, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "RegisterCluster", arg0, arg1) - ret0, _ := ret[0].(*cast.RegisterClusterResponse) + ret0, _ := ret[0].(*castai.RegisterClusterResponse) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -51,7 +51,7 @@ func (mr *MockClientMockRecorder) RegisterCluster(arg0, arg1 interface{}) *gomoc } // SendClusterSnapshot mocks base method. -func (m *MockClient) SendClusterSnapshot(arg0 context.Context, arg1 *cast.Snapshot) error { +func (m *MockClient) SendClusterSnapshot(arg0 context.Context, arg1 *castai.Snapshot) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SendClusterSnapshot", arg0, arg1) ret0, _ := ret[0].(error) diff --git a/internal/cast/types.go b/internal/castai/types.go similarity index 98% rename from internal/cast/types.go rename to internal/castai/types.go index 7c5e836a..141c752d 100644 --- a/internal/cast/types.go +++ b/internal/castai/types.go @@ -1,4 +1,4 @@ -package cast +package castai import "castai-agent/internal/services/collector" diff --git a/internal/config/config.go b/internal/config/config.go index 69a8086c..105be61e 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -2,6 +2,7 @@ package config import ( "fmt" + "github.com/spf13/viper" ) @@ -9,6 +10,7 @@ type Config struct { API API Kubeconfig string Provider string + CASTAI *CASTAI EKS *EKS } @@ -17,6 +19,11 @@ type API struct { URL string } +type CASTAI struct { + ClusterID string + OrganizationID string +} + type EKS struct { AccountID string Region string @@ -25,6 +32,7 @@ type EKS struct { var cfg *Config +// Get configuration bound to environment variables. func Get() Config { if cfg != nil { return *cfg @@ -37,6 +45,9 @@ func Get() Config { _ = viper.BindEnv("provider") + _ = viper.BindEnv("castai.clusterid", "CASTAI_CLUSTER_ID") + _ = viper.BindEnv("castai.organizationid", "CASTAI_ORGANIZATION_ID") + _ = viper.BindEnv("eks.accountid", "EKS_ACCOUNT_ID") _ = viper.BindEnv("eks.region", "EKS_REGION") _ = viper.BindEnv("eks.clustername", "EKS_CLUSTER_NAME") @@ -53,6 +64,15 @@ func Get() Config { required("API_URL") } + if cfg.CASTAI != nil { + if cfg.CASTAI.ClusterID == "" { + requiredDiscoveryDisabled("CASTAI_CLUSTER_ID") + } + if cfg.CASTAI.OrganizationID == "" { + requiredDiscoveryDisabled("CASTAI_ORGANIZATION_ID") + } + } + if cfg.EKS != nil { if cfg.EKS.AccountID == "" { requiredDiscoveryDisabled("EKS_ACCOUNT_ID") @@ -68,6 +88,11 @@ func Get() Config { return *cfg } +// Reset is used only for unit testing to reset configuration and rebind variables. +func Reset() { + cfg = nil +} + func required(variable string) { panic(fmt.Errorf("env variable %s is required", variable)) } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index e2bcfc6a..02292dcc 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -1,9 +1,10 @@ package config import ( - "github.com/stretchr/testify/require" "os" "testing" + + "github.com/stretchr/testify/require" ) func TestConfig(t *testing.T) { diff --git a/internal/services/collector/collector.go b/internal/services/collector/collector.go index 4a60bfa7..d472da60 100644 --- a/internal/services/collector/collector.go +++ b/internal/services/collector/collector.go @@ -4,12 +4,13 @@ package collector import ( "context" "fmt" + "regexp" + "strconv" + "github.com/sirupsen/logrus" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/version" "k8s.io/client-go/kubernetes" - "regexp" - "strconv" ) // Collector is responsible for gathering K8s data from the cluster. diff --git a/internal/services/providers/castai/castai.go b/internal/services/providers/castai/castai.go new file mode 100644 index 00000000..16fb4c42 --- /dev/null +++ b/internal/services/providers/castai/castai.go @@ -0,0 +1,61 @@ +package castai + +import ( + "context" + + "github.com/sirupsen/logrus" + "k8s.io/api/core/v1" + + "castai-agent/internal/castai" + "castai-agent/internal/config" + "castai-agent/internal/services/providers/types" + "castai-agent/pkg/labels" +) + +const ( + Name = "castai" +) + +func New(_ context.Context, log logrus.FieldLogger) (types.Provider, error) { + return &Provider{log: log}, nil +} + +type Provider struct { + log logrus.FieldLogger +} + +func (p *Provider) RegisterCluster(_ context.Context, _ castai.Client) (*types.ClusterRegistration, error) { + cfg := config.Get().CASTAI + return &types.ClusterRegistration{ + ClusterID: cfg.ClusterID, + OrganizationID: cfg.OrganizationID, + }, nil +} + +func (p *Provider) Name() string { + return Name +} + +func (p *Provider) FilterSpot(_ context.Context, nodes []*v1.Node) ([]*v1.Node, error) { + var spots []*v1.Node + + for _, n := range nodes { + if val, ok := n.ObjectMeta.Labels[labels.Spot]; ok && val == "true" { + spots = append(spots, n) + } + } + + return spots, nil +} + +func (p *Provider) AccountID(_ context.Context) (string, error) { + return "", nil +} + +func (p *Provider) ClusterName(_ context.Context) (string, error) { + return "", nil +} + +func (p *Provider) ClusterRegion(_ context.Context) (string, error) { + return "", nil +} diff --git a/internal/services/providers/eks/client/aws.go b/internal/services/providers/eks/client/aws.go index 90016f8e..380d359a 100644 --- a/internal/services/providers/eks/client/aws.go +++ b/internal/services/providers/eks/client/aws.go @@ -5,14 +5,15 @@ import ( "context" "errors" "fmt" + "math" + "strings" + "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/ec2metadata" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/ec2" "github.com/sirupsen/logrus" "k8s.io/utils/pointer" - "math" - "strings" ) // Client is an abstraction on the AWS SDK to enable easier mocking and manipulation of request data. diff --git a/internal/services/providers/eks/eks.go b/internal/services/providers/eks/eks.go index 25410a20..370c8b79 100644 --- a/internal/services/providers/eks/eks.go +++ b/internal/services/providers/eks/eks.go @@ -1,15 +1,18 @@ package eks import ( - "castai-agent/internal/cast" - "castai-agent/internal/config" - "castai-agent/internal/services/providers/eks/client" "context" "fmt" + "time" + "github.com/patrickmn/go-cache" "github.com/sirupsen/logrus" "k8s.io/api/core/v1" - "time" + + "castai-agent/internal/castai" + "castai-agent/internal/config" + "castai-agent/internal/services/providers/eks/client" + "castai-agent/internal/services/providers/types" ) const ( @@ -17,7 +20,7 @@ const ( ) // New configures and returns an EKS provider. -func New(ctx context.Context, log logrus.FieldLogger) (*Provider, error) { +func New(ctx context.Context, log logrus.FieldLogger) (types.Provider, error) { log = log.WithField("provider", Name) var opts []client.Opt @@ -48,6 +51,40 @@ type Provider struct { spotCache *cache.Cache } +func (p *Provider) RegisterCluster(ctx context.Context, client castai.Client) (*types.ClusterRegistration, error) { + cn, err := p.awsClient.GetClusterName(ctx) + if err != nil { + return nil, fmt.Errorf("getting cluster name: %w", err) + } + r, err := p.awsClient.GetRegion(ctx) + if err != nil { + return nil, fmt.Errorf("getting region: %w", err) + } + accID, err := p.awsClient.GetAccountID(ctx) + if err != nil { + return nil, fmt.Errorf("getting account id: %w", err) + } + + req := &castai.RegisterClusterRequest{ + Name: *cn, + EKS: castai.EKSParams{ + ClusterName: *cn, + Region: *r, + AccountID: *accID, + }, + } + + resp, err := client.RegisterCluster(ctx, req) + if err != nil { + return nil, fmt.Errorf("requesting castai api: %w", err) + } + + return &types.ClusterRegistration{ + ClusterID: resp.ID, + OrganizationID: resp.OrganizationID, + }, nil +} + func (p *Provider) FilterSpot(ctx context.Context, nodes []*v1.Node) ([]*v1.Node, error) { if p.spotCache == nil { p.spotCache = cache.New(60*time.Minute, 10*time.Minute) @@ -124,7 +161,7 @@ func (p *Provider) AccountID(ctx context.Context) (string, error) { return *accID, nil } -func (p *Provider) RegisterClusterRequest(ctx context.Context) (*cast.RegisterClusterRequest, error) { +func (p *Provider) RegisterClusterRequest(ctx context.Context) (*castai.RegisterClusterRequest, error) { cn, err := p.awsClient.GetClusterName(ctx) if err != nil { return nil, fmt.Errorf("getting cluster name: %w", err) @@ -137,9 +174,9 @@ func (p *Provider) RegisterClusterRequest(ctx context.Context) (*cast.RegisterCl if err != nil { return nil, fmt.Errorf("getting account id: %w", err) } - return &cast.RegisterClusterRequest{ + return &castai.RegisterClusterRequest{ Name: *cn, - EKS: cast.EKSParams{ + EKS: castai.EKSParams{ ClusterName: *cn, Region: *r, AccountID: *accID, diff --git a/internal/services/providers/eks/eks_test.go b/internal/services/providers/eks/eks_test.go index a4f3d147..b70f7be2 100644 --- a/internal/services/providers/eks/eks_test.go +++ b/internal/services/providers/eks/eks_test.go @@ -1,22 +1,29 @@ package eks import ( - "castai-agent/internal/cast" - mock_client "castai-agent/internal/services/providers/eks/client/mock" "context" + "testing" + "github.com/aws/aws-sdk-go/service/ec2" "github.com/golang/mock/gomock" + "github.com/google/uuid" "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/utils/pointer" - "testing" + + "castai-agent/internal/castai" + mock_castai "castai-agent/internal/castai/mock" + mock_client "castai-agent/internal/services/providers/eks/client/mock" + "castai-agent/internal/services/providers/types" ) -func TestProvider_RegisterClusterRequest(t *testing.T) { +func TestProvider_RegisterCluster(t *testing.T) { ctx := context.Background() - awsClient := mock_client.NewMockClient(gomock.NewController(t)) + mockctrl := gomock.NewController(t) + awsClient := mock_client.NewMockClient(mockctrl) + castClient := mock_castai.NewMockClient(mockctrl) p := &Provider{ log: logrus.New(), @@ -27,16 +34,26 @@ func TestProvider_RegisterClusterRequest(t *testing.T) { awsClient.EXPECT().GetRegion(ctx).Return(pointer.StringPtr("eu-central-1"), nil) awsClient.EXPECT().GetAccountID(ctx).Return(pointer.StringPtr("id"), nil) - expected := &cast.RegisterClusterRequest{ + expectedReq := &castai.RegisterClusterRequest{ Name: "test", - EKS: cast.EKSParams{ + EKS: castai.EKSParams{ ClusterName: "test", Region: "eu-central-1", AccountID: "id", }, } - got, err := p.RegisterClusterRequest(ctx) + expected := &types.ClusterRegistration{ + ClusterID: uuid.New().String(), + OrganizationID: uuid.New().String(), + } + + castClient.EXPECT().RegisterCluster(ctx, expectedReq).Return(&castai.RegisterClusterResponse{Cluster: castai.Cluster{ + ID: expected.ClusterID, + OrganizationID: expected.OrganizationID, + }}, nil) + + got, err := p.RegisterCluster(ctx, castClient) require.NoError(t, err) require.Equal(t, expected, got) diff --git a/internal/services/providers/providers.go b/internal/services/providers/providers.go index 9d56f08d..06e68cb6 100644 --- a/internal/services/providers/providers.go +++ b/internal/services/providers/providers.go @@ -1,53 +1,37 @@ -//go:generate mockgen -destination ./mock/provider.go . Provider package providers import ( - "castai-agent/internal/cast" - "castai-agent/internal/config" - "castai-agent/internal/services/providers/eks" "context" "errors" "fmt" + "github.com/sirupsen/logrus" - v1 "k8s.io/api/core/v1" -) -// Provider is an abstraction for various CAST AI supported K8s providers, like EKS, GKE, etc. -type Provider interface { - // RegisterClusterRequest retrieves all the required data needed to correctly register the cluster with CAST AI. - RegisterClusterRequest(ctx context.Context) (*cast.RegisterClusterRequest, error) - // FilterSpot returns a list of nodes which are configured as spot/preemtible instances. - FilterSpot(ctx context.Context, nodes []*v1.Node) ([]*v1.Node, error) - // Name of the provider. - Name() string - // AccountID of the EC2 instance. - // Deprecated: snapshot should not include cluster metadata as it already is known via register cluster request. - AccountID(ctx context.Context) (string, error) - // ClusterName of the of the EKS cluster. - // Deprecated: snapshot should not include cluster name as it already is known via register cluster request. - ClusterName(ctx context.Context) (string, error) - // ClusterRegion of the EC2 instance. - // Deprecated: snapshot should not include cluster metadata as it already is known via register cluster request. - ClusterRegion(ctx context.Context) (string, error) -} + "castai-agent/internal/config" + "castai-agent/internal/services/providers/castai" + "castai-agent/internal/services/providers/eks" + "castai-agent/internal/services/providers/types" +) -func GetProvider(ctx context.Context, log logrus.FieldLogger) (p Provider, err error) { +func GetProvider(ctx context.Context, log logrus.FieldLogger) (p types.Provider, err error) { cfg := config.Get() - switch cfg.Provider { - case "": - return dynamicProvider(ctx, log) - case eks.Name: - return eks.New(ctx, log) - default: - return nil, fmt.Errorf("unknown provider %q", cfg.Provider) + + if cfg.Provider == castai.Name || cfg.CASTAI != nil { + return castai.New(ctx, log) } -} -func dynamicProvider(ctx context.Context, log logrus.FieldLogger) (Provider, error) { - if config.Get().EKS != nil { + if cfg.Provider == eks.Name || cfg.EKS != nil { return eks.New(ctx, log) } + if cfg.Provider == "" { + return dynamicProvider(ctx, log) + } + + return nil, fmt.Errorf("unknown provider %q", cfg.Provider) +} + +func dynamicProvider(ctx context.Context, log logrus.FieldLogger) (types.Provider, error) { log.Info("using cluster provider discovery") if p, err := eks.New(ctx, log); err != nil { @@ -56,5 +40,5 @@ func dynamicProvider(ctx context.Context, log logrus.FieldLogger) (Provider, err return p, nil } - return nil, errors.New("none of the supported providers were able to initialize") + return nil, errors.New("none of the providers supporting discovery were able to initialize") } diff --git a/internal/services/providers/providers_test.go b/internal/services/providers/providers_test.go index cc25f5c3..1aa1c6a5 100644 --- a/internal/services/providers/providers_test.go +++ b/internal/services/providers/providers_test.go @@ -1,27 +1,45 @@ package providers import ( - "castai-agent/internal/services/providers/eks" "context" - "github.com/sirupsen/logrus" - "github.com/stretchr/testify/require" "os" "testing" + + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" + + "castai-agent/internal/config" + "castai-agent/internal/services/providers/castai" + "castai-agent/internal/services/providers/eks" ) func TestGetProvider(t *testing.T) { - require.NoError(t, os.Setenv("API_KEY", "api-key")) - require.NoError(t, os.Setenv("API_URL", "test")) + t.Run("should return castai", func(t *testing.T) { + t.Cleanup(config.Reset) + + require.NoError(t, os.Setenv("API_KEY", "api-key")) + require.NoError(t, os.Setenv("API_URL", "test")) + require.NoError(t, os.Setenv("PROVIDER", "castai")) + + got, err := GetProvider(context.Background(), logrus.New()) + + require.NoError(t, err) + require.IsType(t, &castai.Provider{}, got) + }) t.Run("should return eks", func(t *testing.T) { + t.Cleanup(config.Reset) + + require.NoError(t, os.Setenv("API_KEY", "api-key")) + require.NoError(t, os.Setenv("API_URL", "test")) require.NoError(t, os.Setenv("PROVIDER", "eks")) require.NoError(t, os.Setenv("EKS_CLUSTER_NAME", "test")) require.NoError(t, os.Setenv("EKS_ACCOUNT_ID", "accountID")) require.NoError(t, os.Setenv("EKS_REGION", "eu-central-1")) - p, err := GetProvider(context.Background(), logrus.New()) + got, err := GetProvider(context.Background(), logrus.New()) require.NoError(t, err) - require.IsType(t, &eks.Provider{}, p) + require.IsType(t, &eks.Provider{}, got) }) } diff --git a/internal/services/providers/mock/provider.go b/internal/services/providers/types/mock/provider.go similarity index 82% rename from internal/services/providers/mock/provider.go rename to internal/services/providers/types/mock/provider.go index c6490061..f15fa2c2 100644 --- a/internal/services/providers/mock/provider.go +++ b/internal/services/providers/types/mock/provider.go @@ -1,11 +1,12 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: castai-agent/internal/services/providers (interfaces: Provider) +// Source: castai-agent/internal/services/providers/types (interfaces: Provider) -// Package mock_providers is a generated GoMock package. -package mock_providers +// Package mock_types is a generated GoMock package. +package mock_types import ( - cast "castai-agent/internal/cast" + castai "castai-agent/internal/castai" + types "castai-agent/internal/services/providers/types" context "context" reflect "reflect" @@ -110,17 +111,17 @@ func (mr *MockProviderMockRecorder) Name() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockProvider)(nil).Name)) } -// RegisterClusterRequest mocks base method. -func (m *MockProvider) RegisterClusterRequest(arg0 context.Context) (*cast.RegisterClusterRequest, error) { +// RegisterCluster mocks base method. +func (m *MockProvider) RegisterCluster(arg0 context.Context, arg1 castai.Client) (*types.ClusterRegistration, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RegisterClusterRequest", arg0) - ret0, _ := ret[0].(*cast.RegisterClusterRequest) + ret := m.ctrl.Call(m, "RegisterCluster", arg0, arg1) + ret0, _ := ret[0].(*types.ClusterRegistration) ret1, _ := ret[1].(error) return ret0, ret1 } -// RegisterClusterRequest indicates an expected call of RegisterClusterRequest. -func (mr *MockProviderMockRecorder) RegisterClusterRequest(arg0 interface{}) *gomock.Call { +// RegisterCluster indicates an expected call of RegisterCluster. +func (mr *MockProviderMockRecorder) RegisterCluster(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterClusterRequest", reflect.TypeOf((*MockProvider)(nil).RegisterClusterRequest), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterCluster", reflect.TypeOf((*MockProvider)(nil).RegisterCluster), arg0, arg1) } diff --git a/internal/services/providers/types/types.go b/internal/services/providers/types/types.go new file mode 100644 index 00000000..edfa2215 --- /dev/null +++ b/internal/services/providers/types/types.go @@ -0,0 +1,35 @@ +//go:generate mockgen -destination ./mock/provider.go . Provider +package types + +import ( + "context" + + v1 "k8s.io/api/core/v1" + + castclient "castai-agent/internal/castai" +) + +// Provider is an abstraction for various CAST AI supported K8s providers, like EKS, GKE, etc. +type Provider interface { + // RegisterCluster retrieves cluster registration data needed to correctly identify the cluster. + RegisterCluster(ctx context.Context, client castclient.Client) (*ClusterRegistration, error) + // FilterSpot returns a list of nodes which are configured as spot/preemtible instances. + FilterSpot(ctx context.Context, nodes []*v1.Node) ([]*v1.Node, error) + // Name of the provider. + Name() string + // AccountID of the EC2 instance. + // Deprecated: snapshot should not include cluster metadata as it already is known via register cluster request. + AccountID(ctx context.Context) (string, error) + // ClusterName of the of the EKS cluster. + // Deprecated: snapshot should not include cluster name as it already is known via register cluster request. + ClusterName(ctx context.Context) (string, error) + // ClusterRegion of the EC2 instance. + // Deprecated: snapshot should not include cluster metadata as it already is known via register cluster request. + ClusterRegion(ctx context.Context) (string, error) +} + +// ClusterRegistration holds information needed to identify the cluster. +type ClusterRegistration struct { + ClusterID string + OrganizationID string +} diff --git a/main.go b/main.go index d2362949..10695938 100644 --- a/main.go +++ b/main.go @@ -1,24 +1,25 @@ package main import ( - "castai-agent/internal/cast" - "castai-agent/internal/config" - "castai-agent/internal/services/providers" "context" "fmt" "io/ioutil" - "k8s.io/api/core/v1" - "k8s.io/client-go/kubernetes" - "sigs.k8s.io/controller-runtime/pkg/manager/signals" "strings" "time" "github.com/sirupsen/logrus" + "k8s.io/api/core/v1" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/rest" "k8s.io/client-go/tools/clientcmd" + "sigs.k8s.io/controller-runtime/pkg/manager/signals" + "castai-agent/internal/castai" + "castai-agent/internal/config" "castai-agent/internal/services/collector" - - "k8s.io/client-go/rest" + "castai-agent/internal/services/providers" + "castai-agent/internal/services/providers/types" + "castai-agent/pkg/labels" ) func main() { @@ -38,14 +39,9 @@ func run(ctx context.Context, log logrus.FieldLogger) error { log = log.WithField("provider", provider.Name()) - registerClusterReq, err := provider.RegisterClusterRequest(ctx) - if err != nil { - return fmt.Errorf("creating register cluster request: %w", err) - } - - castclient := cast.NewClient(log, cast.NewDefaultClient()) + castclient := castai.NewClient(log, castai.NewDefaultClient()) - c, err := castclient.RegisterCluster(ctx, registerClusterReq) + reg, err := provider.RegisterCluster(ctx, castclient) if err != nil { return fmt.Errorf("registering cluster: %w", err) } @@ -70,7 +66,7 @@ func run(ctx context.Context, log logrus.FieldLogger) error { defer ticker.Stop() for { - if err := collect(ctx, log, c, col, provider, castclient); err != nil { + if err := collect(ctx, log, reg, col, provider, castclient); err != nil { log.Errorf("collecting snapshot data: %v", err) } @@ -86,10 +82,10 @@ func run(ctx context.Context, log logrus.FieldLogger) error { func collect( ctx context.Context, log logrus.FieldLogger, - c *cast.RegisterClusterResponse, + reg *types.ClusterRegistration, col collector.Collector, - provider providers.Provider, - castclient cast.Client, + provider types.Provider, + castclient castai.Client, ) error { cd, err := col.Collect(ctx) if err != nil { @@ -111,9 +107,9 @@ func collect( return fmt.Errorf("getting cluster region: %w", err) } - snap := &cast.Snapshot{ - ClusterID: c.Cluster.ID, - OrganizationID: c.Cluster.OrganizationID, + snap := &castai.Snapshot{ + ClusterID: reg.ClusterID, + OrganizationID: reg.OrganizationID, ClusterProvider: strings.ToUpper(provider.Name()), AccountID: accountID, ClusterName: clusterName, @@ -136,7 +132,7 @@ func collect( return nil } -func addSpotLabel(ctx context.Context, provider providers.Provider, nodes *v1.NodeList) error { +func addSpotLabel(ctx context.Context, provider types.Provider, nodes *v1.NodeList) error { nodeMap := make(map[string]*v1.Node, len(nodes.Items)) items := make([]*v1.Node, len(nodes.Items)) for i, node := range nodes.Items { @@ -150,7 +146,7 @@ func addSpotLabel(ctx context.Context, provider providers.Provider, nodes *v1.No } for _, node := range spotNodes { - nodeMap[node.Name].Labels["scheduling.cast.ai/spot"] = "true" + nodeMap[node.Name].Labels[labels.Spot] = "true" } return nil diff --git a/main_test.go b/main_test.go index 131af4e5..ef1da2aa 100644 --- a/main_test.go +++ b/main_test.go @@ -1,12 +1,16 @@ package main import ( - "castai-agent/internal/cast" - mock_cast "castai-agent/internal/cast/mock" + "context" + "testing" + + "castai-agent/internal/castai" + mock_castai "castai-agent/internal/castai/mock" "castai-agent/internal/services/collector" mock_collector "castai-agent/internal/services/collector/mock" - mock_providers "castai-agent/internal/services/providers/mock" - "context" + "castai-agent/internal/services/providers/types" + mock_types "castai-agent/internal/services/providers/types/mock" + "github.com/golang/mock/gomock" "github.com/google/uuid" "github.com/sirupsen/logrus" @@ -14,17 +18,19 @@ import ( v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/version" - "testing" ) func TestCollect(t *testing.T) { ctx := context.Background() mockctrl := gomock.NewController(t) col := mock_collector.NewMockCollector(mockctrl) - provider := mock_providers.NewMockProvider(mockctrl) - castclient := mock_cast.NewMockClient(mockctrl) + provider := mock_types.NewMockProvider(mockctrl) + castclient := mock_castai.NewMockClient(mockctrl) - c := &cast.RegisterClusterResponse{Cluster: cast.Cluster{ID: uuid.New().String(), OrganizationID: uuid.New().String()}} + reg := &types.ClusterRegistration{ + ClusterID: uuid.New().String(), + OrganizationID: uuid.New().String(), + } spot := v1.Node{ObjectMeta: metav1.ObjectMeta{Name: "spot", Labels: map[string]string{}}} onDemand := v1.Node{ObjectMeta: metav1.ObjectMeta{Name: "on-demand"}} @@ -39,9 +45,9 @@ func TestCollect(t *testing.T) { provider.EXPECT().Name().Return("eks") provider.EXPECT().FilterSpot(ctx, []*v1.Node{&spot, &onDemand}).Return([]*v1.Node{&spot}, nil) - castclient.EXPECT().SendClusterSnapshot(ctx, &cast.Snapshot{ - ClusterID: c.Cluster.ID, - OrganizationID: c.Cluster.OrganizationID, + castclient.EXPECT().SendClusterSnapshot(ctx, &castai.Snapshot{ + ClusterID: reg.ClusterID, + OrganizationID: reg.OrganizationID, AccountID: "accountID", ClusterProvider: "EKS", ClusterName: "clusterName", @@ -50,7 +56,7 @@ func TestCollect(t *testing.T) { ClusterVersion: "1.20", }).Return(nil) - err := collect(ctx, logrus.New(), c, col, provider, castclient) + err := collect(ctx, logrus.New(), reg, col, provider, castclient) require.NoError(t, err) diff --git a/pkg/labels/labels.go b/pkg/labels/labels.go new file mode 100644 index 00000000..d7c9b63d --- /dev/null +++ b/pkg/labels/labels.go @@ -0,0 +1,5 @@ +package labels + +const ( + Spot = "scheduling.cast.ai/spot" +)