Skip to content

Commit

Permalink
Fix comments
Browse files Browse the repository at this point in the history
  • Loading branch information
adiantum committed Oct 9, 2024
1 parent 11c7a5d commit d4bcd55
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 16 deletions.
46 changes: 38 additions & 8 deletions pkg/providers/nutanix/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func (v *Validator) ValidateClusterSpec(ctx context.Context, spec *cluster.Spec,
}

for _, conf := range spec.NutanixMachineConfigs {
if err := v.ValidateMachineConfig(ctx, client, conf); err != nil {
if err := v.ValidateMachineConfig(ctx, client, spec.Cluster, conf); err != nil {
return fmt.Errorf("failed to validate machine config: %v", err)
}
}
Expand Down Expand Up @@ -249,7 +249,7 @@ func (v *Validator) validateMachineSpecs(machineSpec anywherev1.NutanixMachineCo
}

// ValidateMachineConfig validates the Prism Element cluster, subnet, and image for the machine.
func (v *Validator) ValidateMachineConfig(ctx context.Context, client Client, config *anywherev1.NutanixMachineConfig) error {
func (v *Validator) ValidateMachineConfig(ctx context.Context, client Client, cluster *anywherev1.Cluster, config *anywherev1.NutanixMachineConfig) error {
if err := v.validateMachineSpecs(config.Spec); err != nil {
return err
}
Expand Down Expand Up @@ -278,7 +278,19 @@ func (v *Validator) ValidateMachineConfig(ctx context.Context, client Client, co
}
}

if err := v.validateGPUInMachineConfig(cluster, config); err != nil {
return err
}

return nil
}

func (v *Validator) validateGPUInMachineConfig(cluster *anywherev1.Cluster, config *anywherev1.NutanixMachineConfig) error {
if config.Spec.GPUs != nil {
if err := checkMachineConfigIsForWorker(config, cluster); err != nil {
return err
}

for _, gpu := range config.Spec.GPUs {
if err := v.validateGPUConfig(gpu); err != nil {
return err
Expand Down Expand Up @@ -654,13 +666,13 @@ func createGetGpuModeFunc(gpuDeviceIDToMode map[int64]string, gpuNameToMode map[
}

func (v *Validator) validateFreeGPU(ctx context.Context, v3Client Client, cluster *cluster.Spec) error {
res, err := v3Client.ListAllHost(ctx)
if err != nil || len(res.Entities) == 0 {
return fmt.Errorf("No GPUs found: %v", err)
}

if v.isGPURequested(cluster.NutanixMachineConfigs) {
err := v.validateGPUModeNotMixed(res.Entities, cluster)
res, err := v3Client.ListAllHost(ctx)
if err != nil || len(res.Entities) == 0 {
return fmt.Errorf("no GPUs found: %v", err)
}

err = v.validateGPUModeNotMixed(res.Entities, cluster)
if err != nil {
return err
}
Expand Down Expand Up @@ -690,6 +702,24 @@ func (v *Validator) validateUpgradeRolloutStrategy(clusterSpec *cluster.Spec) er
return nil
}

func checkMachineConfigIsForWorker(config *anywherev1.NutanixMachineConfig, cluster *anywherev1.Cluster) error {
if config.Name == cluster.Spec.ControlPlaneConfiguration.MachineGroupRef.Name {
return fmt.Errorf("GPUs are not supported for control plane machine")
}

if cluster.Spec.ExternalEtcdConfiguration != nil && config.Name == cluster.Spec.ExternalEtcdConfiguration.MachineGroupRef.Name {
return fmt.Errorf("GPUs are not supported for external etcd machine")
}

for _, workerNodeGroupConfiguration := range cluster.Spec.WorkerNodeGroupConfigurations {
if config.Name == workerNodeGroupConfiguration.MachineGroupRef.Name {
return nil
}
}

return fmt.Errorf("machine config %s is not associated with any worker node group", config.Name)
}

// findSubnetUUIDByName retrieves the subnet uuid by the given subnet name.
func findSubnetUUIDByName(ctx context.Context, v3Client Client, clusterUUID, subnetName string) (*string, error) {
res, err := v3Client.ListSubnet(ctx, &v3.DSMetadata{
Expand Down
111 changes: 103 additions & 8 deletions pkg/providers/nutanix/validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"k8s.io/apimachinery/pkg/api/resource"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"sigs.k8s.io/yaml"

"github.com/aws/eks-anywhere/internal/test"
Expand All @@ -25,7 +26,6 @@ import (
"github.com/aws/eks-anywhere/pkg/constants"
mockCrypto "github.com/aws/eks-anywhere/pkg/crypto/mocks"
mocknutanix "github.com/aws/eks-anywhere/pkg/providers/nutanix/mocks"
"github.com/aws/eks-anywhere/pkg/utils/ptr"
)

//go:embed testdata/datacenterConfig_with_trust_bundle.yaml
Expand Down Expand Up @@ -525,7 +525,7 @@ func TestNutanixValidatorValidateMachineConfig(t *testing.T) {
mockClient.EXPECT().ListProject(gomock.Any(), gomock.Any()).Return(nil, errors.New("project not found"))
machineConf.Spec.Project = &anywherev1.NutanixResourceIdentifier{
Type: anywherev1.NutanixIdentifierName,
Name: ptr.String("notaproject"),
Name: utils.StringPtr("notaproject"),
}
clientCache := &ClientCache{clients: map[string]Client{"test": mockClient}}
return NewValidator(clientCache, validator, &http.Client{Transport: transport})
Expand All @@ -541,7 +541,7 @@ func TestNutanixValidatorValidateMachineConfig(t *testing.T) {
mockClient.EXPECT().ListProject(gomock.Any(), gomock.Any()).Return(&v3.ProjectListResponse{}, nil)
machineConf.Spec.Project = &anywherev1.NutanixResourceIdentifier{
Type: anywherev1.NutanixIdentifierName,
Name: ptr.String("notaproject"),
Name: utils.StringPtr("notaproject"),
}
clientCache := &ClientCache{clients: map[string]Client{"test": mockClient}}
return NewValidator(clientCache, validator, &http.Client{Transport: transport})
Expand All @@ -559,7 +559,7 @@ func TestNutanixValidatorValidateMachineConfig(t *testing.T) {
mockClient.EXPECT().ListProject(gomock.Any(), gomock.Any()).Return(projects, nil)
machineConf.Spec.Project = &anywherev1.NutanixResourceIdentifier{
Type: anywherev1.NutanixIdentifierName,
Name: ptr.String("project"),
Name: utils.StringPtr("project"),
}
clientCache := &ClientCache{clients: map[string]Client{"test": mockClient}}
return NewValidator(clientCache, validator, &http.Client{Transport: transport})
Expand Down Expand Up @@ -625,7 +625,7 @@ func TestNutanixValidatorValidateMachineConfig(t *testing.T) {
mockClient.EXPECT().ListSubnet(gomock.Any(), gomock.Any()).Return(fakeSubnetList(), nil)
mockClient.EXPECT().ListImage(gomock.Any(), gomock.Any()).Return(fakeImageList(), nil)
categoryKey := v3.CategoryKeyStatus{
Name: ptr.String("key"),
Name: utils.StringPtr("key"),
}
mockClient.EXPECT().GetCategoryKey(gomock.Any(), gomock.Any()).Return(&categoryKey, nil)
mockClient.EXPECT().GetCategoryValue(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("category value not found"))
Expand Down Expand Up @@ -704,17 +704,112 @@ func TestNutanixValidatorValidateMachineConfig(t *testing.T) {
},
expectedError: "missing GPU name",
},
{
name: "machine config is not associated with any worker node group",
setup: func(machineConf *anywherev1.NutanixMachineConfig, mockClient *mocknutanix.MockClient, validator *mockCrypto.MockTlsValidator, transport *mocknutanix.MockRoundTripper) *Validator {
mockClient.EXPECT().ListCluster(gomock.Any(), gomock.Any()).Return(fakeClusterList(), nil).Times(2)
mockClient.EXPECT().ListSubnet(gomock.Any(), gomock.Any()).Return(fakeSubnetList(), nil)
mockClient.EXPECT().ListImage(gomock.Any(), gomock.Any()).Return(fakeImageList(), nil)
machineConf.Name = "test-wn"
machineConf.Spec.GPUs = []anywherev1.NutanixGPUIdentifier{
{
Type: "name",
Name: "NVIDIA A40-1Q",
},
}
return NewValidator(&ClientCache{}, validator, &http.Client{Transport: transport})
},
expectedError: "not associated with any worker node group",
},
{
name: "GPUs are not supported for control plane machine",
setup: func(machineConf *anywherev1.NutanixMachineConfig, mockClient *mocknutanix.MockClient, validator *mockCrypto.MockTlsValidator, transport *mocknutanix.MockRoundTripper) *Validator {
mockClient.EXPECT().ListCluster(gomock.Any(), gomock.Any()).Return(fakeClusterList(), nil).Times(2)
mockClient.EXPECT().ListSubnet(gomock.Any(), gomock.Any()).Return(fakeSubnetList(), nil)
mockClient.EXPECT().ListImage(gomock.Any(), gomock.Any()).Return(fakeImageList(), nil)
machineConf.Name = "test-cp"
machineConf.Spec.GPUs = []anywherev1.NutanixGPUIdentifier{
{
Type: "name",
Name: "NVIDIA A40-1Q",
},
}
return NewValidator(&ClientCache{}, validator, &http.Client{Transport: transport})
},
expectedError: "GPUs are not supported for control plane machine",
},
{
name: "GPUs are not supported for external etcd machine",
setup: func(machineConf *anywherev1.NutanixMachineConfig, mockClient *mocknutanix.MockClient, validator *mockCrypto.MockTlsValidator, transport *mocknutanix.MockRoundTripper) *Validator {
mockClient.EXPECT().ListCluster(gomock.Any(), gomock.Any()).Return(fakeClusterList(), nil).Times(2)
mockClient.EXPECT().ListSubnet(gomock.Any(), gomock.Any()).Return(fakeSubnetList(), nil)
mockClient.EXPECT().ListImage(gomock.Any(), gomock.Any()).Return(fakeImageList(), nil)
machineConf.Name = "test-etcd"
machineConf.Spec.GPUs = []anywherev1.NutanixGPUIdentifier{
{
Type: "name",
Name: "NVIDIA A40-1Q",
},
}
return NewValidator(&ClientCache{}, validator, &http.Client{Transport: transport})
},
expectedError: "GPUs are not supported for external etcd machine",
},
{
name: "validation pass",
setup: func(machineConf *anywherev1.NutanixMachineConfig, mockClient *mocknutanix.MockClient, validator *mockCrypto.MockTlsValidator, transport *mocknutanix.MockRoundTripper) *Validator {
mockClient.EXPECT().ListCluster(gomock.Any(), gomock.Any()).Return(fakeClusterList(), nil).Times(2)
mockClient.EXPECT().ListSubnet(gomock.Any(), gomock.Any()).Return(fakeSubnetList(), nil)
mockClient.EXPECT().ListImage(gomock.Any(), gomock.Any()).Return(fakeImageList(), nil)
machineConf.Name = "eksa-unit-test"
machineConf.Spec.GPUs = []anywherev1.NutanixGPUIdentifier{
{
Type: "name",
Name: "NVIDIA A40-1Q",
},
}
return NewValidator(&ClientCache{}, validator, &http.Client{Transport: transport})
},
expectedError: "",
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
cluster := &anywherev1.Cluster{
ObjectMeta: metav1.ObjectMeta{
Name: "test",
},
Spec: anywherev1.ClusterSpec{
ControlPlaneConfiguration: anywherev1.ControlPlaneConfiguration{
MachineGroupRef: &anywherev1.Ref{
Kind: "NutanixMachineConfig",
Name: "test-cp",
},
},
ExternalEtcdConfiguration: &anywherev1.ExternalEtcdConfiguration{
MachineGroupRef: &anywherev1.Ref{
Kind: "NutanixMachineConfig",
Name: "test-etcd",
},
},
WorkerNodeGroupConfigurations: []anywherev1.WorkerNodeGroupConfiguration{
{
MachineGroupRef: &anywherev1.Ref{
Kind: "NutanixMachineConfig",
Name: "eksa-unit-test",
},
},
},
},
}
machineConfig := &anywherev1.NutanixMachineConfig{}
err := yaml.Unmarshal([]byte(nutanixMachineConfigSpec), machineConfig)
require.NoError(t, err)

mockClient := mocknutanix.NewMockClient(ctrl)
validator := tc.setup(machineConfig, mockClient, mockCrypto.NewMockTlsValidator(ctrl), mocknutanix.NewMockRoundTripper(ctrl))
err = validator.ValidateMachineConfig(context.Background(), mockClient, machineConfig)
err = validator.ValidateMachineConfig(context.Background(), mockClient, cluster, machineConfig)
if tc.expectedError != "" {
assert.Contains(t, err.Error(), tc.expectedError)
} else {
Expand Down Expand Up @@ -1163,7 +1258,7 @@ func TestNutanixValidatorValidateFreeGPU(t *testing.T) {
clientCache := &ClientCache{clients: map[string]Client{"test": mockClient}}
return NewValidator(clientCache, validator, &http.Client{Transport: transport})
},
expectedError: "No GPUs found",
expectedError: "no GPUs found",
},
{
name: "no GPU resources found: ListAllHost failed",
Expand Down Expand Up @@ -1191,7 +1286,7 @@ func TestNutanixValidatorValidateFreeGPU(t *testing.T) {
clientCache := &ClientCache{clients: map[string]Client{"test": mockClient}}
return NewValidator(clientCache, validator, &http.Client{Transport: transport})
},
expectedError: "No GPUs found",
expectedError: "no GPUs found",
},
{
name: "mixed passthrough and vGPU mode GPUs in a machine config",
Expand Down

0 comments on commit d4bcd55

Please sign in to comment.