diff --git a/pkg/providers/nutanix/validator.go b/pkg/providers/nutanix/validator.go index 1e23c924292f..3a14d6451da1 100644 --- a/pkg/providers/nutanix/validator.go +++ b/pkg/providers/nutanix/validator.go @@ -62,7 +62,7 @@ func (v *Validator) ValidateClusterSpec(ctx context.Context, spec *cluster.Spec, } for _, conf := range spec.NutanixMachineConfigs { - if err := v.ValidateMachineConfig(ctx, client, conf); err != nil { + if err := v.ValidateMachineConfig(ctx, client, spec.Cluster, conf); err != nil { return fmt.Errorf("failed to validate machine config: %v", err) } } @@ -249,7 +249,7 @@ func (v *Validator) validateMachineSpecs(machineSpec anywherev1.NutanixMachineCo } // ValidateMachineConfig validates the Prism Element cluster, subnet, and image for the machine. -func (v *Validator) ValidateMachineConfig(ctx context.Context, client Client, config *anywherev1.NutanixMachineConfig) error { +func (v *Validator) ValidateMachineConfig(ctx context.Context, client Client, cluster *anywherev1.Cluster, config *anywherev1.NutanixMachineConfig) error { if err := v.validateMachineSpecs(config.Spec); err != nil { return err } @@ -278,7 +278,19 @@ func (v *Validator) ValidateMachineConfig(ctx context.Context, client Client, co } } + if err := v.validateGPUInMachineConfig(cluster, config); err != nil { + return err + } + + return nil +} + +func (v *Validator) validateGPUInMachineConfig(cluster *anywherev1.Cluster, config *anywherev1.NutanixMachineConfig) error { if config.Spec.GPUs != nil { + if err := checkMachineConfigIsForWorker(config, cluster); err != nil { + return err + } + for _, gpu := range config.Spec.GPUs { if err := v.validateGPUConfig(gpu); err != nil { return err @@ -654,13 +666,13 @@ func createGetGpuModeFunc(gpuDeviceIDToMode map[int64]string, gpuNameToMode map[ } func (v *Validator) validateFreeGPU(ctx context.Context, v3Client Client, cluster *cluster.Spec) error { - res, err := v3Client.ListAllHost(ctx) - if err != nil || len(res.Entities) == 0 { - return 
fmt.Errorf("No GPUs found: %v", err) - } - if v.isGPURequested(cluster.NutanixMachineConfigs) { - err := v.validateGPUModeNotMixed(res.Entities, cluster) + res, err := v3Client.ListAllHost(ctx) + if err != nil || len(res.Entities) == 0 { + return fmt.Errorf("no GPUs found: %v", err) + } + + err = v.validateGPUModeNotMixed(res.Entities, cluster) if err != nil { return err } @@ -690,6 +702,24 @@ func (v *Validator) validateUpgradeRolloutStrategy(clusterSpec *cluster.Spec) er return nil } +func checkMachineConfigIsForWorker(config *anywherev1.NutanixMachineConfig, cluster *anywherev1.Cluster) error { /* checkMachineConfigIsForWorker reports whether the machine config named config.Name may carry GPUs. It returns nil only when config.Name equals the MachineGroupRef.Name of one of cluster.Spec.WorkerNodeGroupConfigurations. It returns an error when the name equals the control plane machine group name, when ExternalEtcdConfiguration is non-nil and the name equals its machine group name, or when no worker node group references the name. NOTE(review): MachineGroupRef is a pointer and is dereferenced here without a nil check on all three paths; confirm that earlier validation guarantees it is set. */ + if config.Name == cluster.Spec.ControlPlaneConfiguration.MachineGroupRef.Name { /* checked before the worker groups, so a machine config shared by the control plane and a worker group is rejected */ + return fmt.Errorf("GPUs are not supported for control plane machine") + } + + if cluster.Spec.ExternalEtcdConfiguration != nil && config.Name == cluster.Spec.ExternalEtcdConfiguration.MachineGroupRef.Name { /* external etcd is optional, hence the nil guard on the configuration itself */ + return fmt.Errorf("GPUs are not supported for external etcd machine") + } + + for _, workerNodeGroupConfiguration := range cluster.Spec.WorkerNodeGroupConfigurations { + if config.Name == workerNodeGroupConfiguration.MachineGroupRef.Name { + return nil + } + } + + return fmt.Errorf("machine config %s is not associated with any worker node group", config.Name) /* reached only when neither the control plane, external etcd, nor any worker node group references config.Name */ +} + // findSubnetUUIDByName retrieves the subnet uuid by the given subnet name. 
func findSubnetUUIDByName(ctx context.Context, v3Client Client, clusterUUID, subnetName string) (*string, error) { res, err := v3Client.ListSubnet(ctx, &v3.DSMetadata{ diff --git a/pkg/providers/nutanix/validator_test.go b/pkg/providers/nutanix/validator_test.go index d0638232e9af..7ef81a156aa3 100644 --- a/pkg/providers/nutanix/validator_test.go +++ b/pkg/providers/nutanix/validator_test.go @@ -16,6 +16,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "sigs.k8s.io/yaml" "github.com/aws/eks-anywhere/internal/test" @@ -25,7 +26,6 @@ import ( "github.com/aws/eks-anywhere/pkg/constants" mockCrypto "github.com/aws/eks-anywhere/pkg/crypto/mocks" mocknutanix "github.com/aws/eks-anywhere/pkg/providers/nutanix/mocks" - "github.com/aws/eks-anywhere/pkg/utils/ptr" ) //go:embed testdata/datacenterConfig_with_trust_bundle.yaml @@ -525,7 +525,7 @@ func TestNutanixValidatorValidateMachineConfig(t *testing.T) { mockClient.EXPECT().ListProject(gomock.Any(), gomock.Any()).Return(nil, errors.New("project not found")) machineConf.Spec.Project = &anywherev1.NutanixResourceIdentifier{ Type: anywherev1.NutanixIdentifierName, - Name: ptr.String("notaproject"), + Name: utils.StringPtr("notaproject"), } clientCache := &ClientCache{clients: map[string]Client{"test": mockClient}} return NewValidator(clientCache, validator, &http.Client{Transport: transport}) @@ -541,7 +541,7 @@ func TestNutanixValidatorValidateMachineConfig(t *testing.T) { mockClient.EXPECT().ListProject(gomock.Any(), gomock.Any()).Return(&v3.ProjectListResponse{}, nil) machineConf.Spec.Project = &anywherev1.NutanixResourceIdentifier{ Type: anywherev1.NutanixIdentifierName, - Name: ptr.String("notaproject"), + Name: utils.StringPtr("notaproject"), } clientCache := &ClientCache{clients: map[string]Client{"test": mockClient}} return NewValidator(clientCache, validator, &http.Client{Transport: 
transport}) @@ -559,7 +559,7 @@ func TestNutanixValidatorValidateMachineConfig(t *testing.T) { mockClient.EXPECT().ListProject(gomock.Any(), gomock.Any()).Return(projects, nil) machineConf.Spec.Project = &anywherev1.NutanixResourceIdentifier{ Type: anywherev1.NutanixIdentifierName, - Name: ptr.String("project"), + Name: utils.StringPtr("project"), } clientCache := &ClientCache{clients: map[string]Client{"test": mockClient}} return NewValidator(clientCache, validator, &http.Client{Transport: transport}) @@ -625,7 +625,7 @@ func TestNutanixValidatorValidateMachineConfig(t *testing.T) { mockClient.EXPECT().ListSubnet(gomock.Any(), gomock.Any()).Return(fakeSubnetList(), nil) mockClient.EXPECT().ListImage(gomock.Any(), gomock.Any()).Return(fakeImageList(), nil) categoryKey := v3.CategoryKeyStatus{ - Name: ptr.String("key"), + Name: utils.StringPtr("key"), } mockClient.EXPECT().GetCategoryKey(gomock.Any(), gomock.Any()).Return(&categoryKey, nil) mockClient.EXPECT().GetCategoryValue(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("category value not found")) @@ -704,17 +704,112 @@ func TestNutanixValidatorValidateMachineConfig(t *testing.T) { }, expectedError: "missing GPU name", }, + { + name: "machine config is not associated with any worker node group", + setup: func(machineConf *anywherev1.NutanixMachineConfig, mockClient *mocknutanix.MockClient, validator *mockCrypto.MockTlsValidator, transport *mocknutanix.MockRoundTripper) *Validator { + mockClient.EXPECT().ListCluster(gomock.Any(), gomock.Any()).Return(fakeClusterList(), nil).Times(2) + mockClient.EXPECT().ListSubnet(gomock.Any(), gomock.Any()).Return(fakeSubnetList(), nil) + mockClient.EXPECT().ListImage(gomock.Any(), gomock.Any()).Return(fakeImageList(), nil) + machineConf.Name = "test-wn" + machineConf.Spec.GPUs = []anywherev1.NutanixGPUIdentifier{ + { + Type: "name", + Name: "NVIDIA A40-1Q", + }, + } + return NewValidator(&ClientCache{}, validator, &http.Client{Transport: transport}) + }, + 
expectedError: "not associated with any worker node group", + }, + { + name: "GPUs are not supported for control plane machine", + setup: func(machineConf *anywherev1.NutanixMachineConfig, mockClient *mocknutanix.MockClient, validator *mockCrypto.MockTlsValidator, transport *mocknutanix.MockRoundTripper) *Validator { + mockClient.EXPECT().ListCluster(gomock.Any(), gomock.Any()).Return(fakeClusterList(), nil).Times(2) + mockClient.EXPECT().ListSubnet(gomock.Any(), gomock.Any()).Return(fakeSubnetList(), nil) + mockClient.EXPECT().ListImage(gomock.Any(), gomock.Any()).Return(fakeImageList(), nil) + machineConf.Name = "test-cp" + machineConf.Spec.GPUs = []anywherev1.NutanixGPUIdentifier{ + { + Type: "name", + Name: "NVIDIA A40-1Q", + }, + } + return NewValidator(&ClientCache{}, validator, &http.Client{Transport: transport}) + }, + expectedError: "GPUs are not supported for control plane machine", + }, + { + name: "GPUs are not supported for external etcd machine", + setup: func(machineConf *anywherev1.NutanixMachineConfig, mockClient *mocknutanix.MockClient, validator *mockCrypto.MockTlsValidator, transport *mocknutanix.MockRoundTripper) *Validator { + mockClient.EXPECT().ListCluster(gomock.Any(), gomock.Any()).Return(fakeClusterList(), nil).Times(2) + mockClient.EXPECT().ListSubnet(gomock.Any(), gomock.Any()).Return(fakeSubnetList(), nil) + mockClient.EXPECT().ListImage(gomock.Any(), gomock.Any()).Return(fakeImageList(), nil) + machineConf.Name = "test-etcd" + machineConf.Spec.GPUs = []anywherev1.NutanixGPUIdentifier{ + { + Type: "name", + Name: "NVIDIA A40-1Q", + }, + } + return NewValidator(&ClientCache{}, validator, &http.Client{Transport: transport}) + }, + expectedError: "GPUs are not supported for external etcd machine", + }, + { + name: "validation pass", + setup: func(machineConf *anywherev1.NutanixMachineConfig, mockClient *mocknutanix.MockClient, validator *mockCrypto.MockTlsValidator, transport *mocknutanix.MockRoundTripper) *Validator { + 
mockClient.EXPECT().ListCluster(gomock.Any(), gomock.Any()).Return(fakeClusterList(), nil).Times(2) + mockClient.EXPECT().ListSubnet(gomock.Any(), gomock.Any()).Return(fakeSubnetList(), nil) + mockClient.EXPECT().ListImage(gomock.Any(), gomock.Any()).Return(fakeImageList(), nil) + machineConf.Name = "eksa-unit-test" + machineConf.Spec.GPUs = []anywherev1.NutanixGPUIdentifier{ + { + Type: "name", + Name: "NVIDIA A40-1Q", + }, + } + return NewValidator(&ClientCache{}, validator, &http.Client{Transport: transport}) + }, + expectedError: "", + }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { + cluster := &anywherev1.Cluster{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + }, + Spec: anywherev1.ClusterSpec{ + ControlPlaneConfiguration: anywherev1.ControlPlaneConfiguration{ + MachineGroupRef: &anywherev1.Ref{ + Kind: "NutanixMachineConfig", + Name: "test-cp", + }, + }, + ExternalEtcdConfiguration: &anywherev1.ExternalEtcdConfiguration{ + MachineGroupRef: &anywherev1.Ref{ + Kind: "NutanixMachineConfig", + Name: "test-etcd", + }, + }, + WorkerNodeGroupConfigurations: []anywherev1.WorkerNodeGroupConfiguration{ + { + MachineGroupRef: &anywherev1.Ref{ + Kind: "NutanixMachineConfig", + Name: "eksa-unit-test", + }, + }, + }, + }, + } machineConfig := &anywherev1.NutanixMachineConfig{} err := yaml.Unmarshal([]byte(nutanixMachineConfigSpec), machineConfig) require.NoError(t, err) mockClient := mocknutanix.NewMockClient(ctrl) validator := tc.setup(machineConfig, mockClient, mockCrypto.NewMockTlsValidator(ctrl), mocknutanix.NewMockRoundTripper(ctrl)) - err = validator.ValidateMachineConfig(context.Background(), mockClient, machineConfig) + err = validator.ValidateMachineConfig(context.Background(), mockClient, cluster, machineConfig) if tc.expectedError != "" { assert.Contains(t, err.Error(), tc.expectedError) } else { @@ -1163,7 +1258,7 @@ func TestNutanixValidatorValidateFreeGPU(t *testing.T) { clientCache := &ClientCache{clients: map[string]Client{"test": 
mockClient}} return NewValidator(clientCache, validator, &http.Client{Transport: transport}) }, - expectedError: "No GPUs found", + expectedError: "no GPUs found", }, { name: "no GPU resources found: ListAllHost failed", @@ -1191,7 +1286,7 @@ func TestNutanixValidatorValidateFreeGPU(t *testing.T) { clientCache := &ClientCache{clients: map[string]Client{"test": mockClient}} return NewValidator(clientCache, validator, &http.Client{Transport: transport}) }, - expectedError: "No GPUs found", + expectedError: "no GPUs found", }, { name: "mixed passthrough and vGPU mode GPUs in a machine config",