diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/config/config.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/config/config.go index eb19015586..adb2d655bb 100644 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/config/config.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/config/config.go @@ -64,8 +64,9 @@ var ( DefaultPodTemplateResync: config2.Duration{ Duration: 30 * time.Second, }, - UpdateBaseBackoffDuration: 10, - UpdateBackoffRetries: 5, + UpdateBaseBackoffDuration: 10, + UpdateBackoffRetries: 5, + AddTolerationsForExtendedResources: []string{}, } // K8sPluginConfigSection provides a singular top level config section for all plugins. @@ -214,6 +215,9 @@ type K8sPluginConfig struct { // Number of retries for exponential backoff when updating a resource. UpdateBackoffRetries int `json:"update-backoff-retries" pflag:",Number of retries for exponential backoff when updating a resource."` + + // Extended resources that should be added to the tolerations automatically. + AddTolerationsForExtendedResources []string `json:"add-tolerations-for-extended-resources" pflag:",Name of the extended resources for which tolerations should be added."` } // FlyteCoPilotConfig specifies configuration for the Flyte CoPilot system. FlyteCoPilot, allows running flytekit-less containers diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/config/k8spluginconfig_flags.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/config/k8spluginconfig_flags.go index 4652d0bfd4..caa485ff39 100755 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/config/k8spluginconfig_flags.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/config/k8spluginconfig_flags.go @@ -69,5 +69,6 @@ func (cfg K8sPluginConfig) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "send-object-events"), defaultK8sConfig.SendObjectEvents, "If true, will send k8s object events in TaskExecutionEvent updates.") cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "update-base-backoff-duration"), defaultK8sConfig.UpdateBaseBackoffDuration, "Initial delay in exponential backoff when updating a resource in milliseconds.") cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "update-backoff-retries"), defaultK8sConfig.UpdateBackoffRetries, "Number of retries for exponential backoff when updating a resource.") + cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "add-tolerations-for-extended-resources"), defaultK8sConfig.AddTolerationsForExtendedResources, "Name of the extended resources for which tolerations should be added.") return cmdFlags } diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/config/k8spluginconfig_flags_test.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/config/k8spluginconfig_flags_test.go index cc46ffa466..cb50078620 100755 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/config/k8spluginconfig_flags_test.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/config/k8spluginconfig_flags_test.go @@ -365,4 +365,18 @@ func TestK8sPluginConfig_SetFlags(t *testing.T) { } }) }) + t.Run("Test_add-tolerations-for-extended-resources", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := join_K8sPluginConfig(defaultK8sConfig.AddTolerationsForExtendedResources, ",") + + cmdFlags.Set("add-tolerations-for-extended-resources", testValue) + if vStringSlice, err := cmdFlags.GetStringSlice("add-tolerations-for-extended-resources"); err == nil { + testDecodeRaw_K8sPluginConfig(t, join_K8sPluginConfig(vStringSlice, ","), &actual.AddTolerationsForExtendedResources) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) } diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go index 53acac5512..6beca78f54 100644 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go @@ -11,6 +11,7 @@ import ( "github.com/imdario/mergo" v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/sets" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" pluginserrors "github.com/flyteorg/flyte/flyteplugins/go/tasks/errors" @@ -445,6 +446,54 @@ func ApplyContainerImageOverride(podSpec *v1.PodSpec, containerImage string, pri } } +func addTolerationInPodSpec(podSpec *v1.PodSpec, toleration *v1.Toleration) *v1.PodSpec { + podTolerations := podSpec.Tolerations + + var newTolerations []v1.Toleration + for i := range podTolerations { + if toleration.MatchToleration(&podTolerations[i]) { + return podSpec + } + newTolerations = append(newTolerations, podTolerations[i]) + } + newTolerations = append(newTolerations, *toleration) + podSpec.Tolerations = newTolerations + return podSpec +} + +func AddTolerationsForExtendedResources(podSpec *v1.PodSpec) *v1.PodSpec { + if podSpec == nil { + podSpec = &v1.PodSpec{} + } + + resources := sets.NewString() + for _, container := range podSpec.Containers { + for _, extendedResource := range config.GetK8sPluginConfig().AddTolerationsForExtendedResources { + if _, ok := container.Resources.Requests[v1.ResourceName(extendedResource)]; ok { + resources.Insert(extendedResource) + } + } + } + + for _, container := range podSpec.InitContainers { + for _, extendedResource := range config.GetK8sPluginConfig().AddTolerationsForExtendedResources { + if _, ok := container.Resources.Requests[v1.ResourceName(extendedResource)]; ok { + resources.Insert(extendedResource) + } + } + } + + for _, resource := range resources.List() { + addTolerationInPodSpec(podSpec, &v1.Toleration{ + Key: resource, + Operator: v1.TolerationOpExists, + Effect: v1.TaintEffectNoSchedule, + }) + } + + return podSpec +} + // ToK8sPodSpec builds a PodSpec and ObjectMeta based on the definition passed by the TaskExecutionContext. This // involves parsing the raw PodSpec definition and applying all Flyte configuration options. func ToK8sPodSpec(ctx context.Context, tCtx pluginsCore.TaskExecutionContext) (*v1.PodSpec, *metav1.ObjectMeta, string, error) { @@ -460,6 +509,8 @@ func ToK8sPodSpec(ctx context.Context, tCtx pluginsCore.TaskExecutionContext) (* return nil, nil, "", err } + podSpec = AddTolerationsForExtendedResources(podSpec) + return podSpec, objectMeta, primaryContainerName, nil } diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go index 65194d01be..0a70cdd895 100644 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go @@ -3,6 +3,7 @@ package flytek8s import ( "context" "encoding/json" + "fmt" "io/ioutil" "path/filepath" "reflect" @@ -2244,3 +2245,112 @@ func TestAddFlyteCustomizationsToContainer_SetConsoleUrl(t *testing.T) { }) } } + +func TestAddTolerationsForExtendedResources(t *testing.T) { + gpuResourceName := v1.ResourceName("nvidia.com/gpu") + addTolerationResourceName := v1.ResourceName("foo/bar") + noTolerationResourceName := v1.ResourceName("foo/baz") + assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{ + GpuResourceName: gpuResourceName, + AddTolerationsForExtendedResources: []string{ + gpuResourceName.String(), + addTolerationResourceName.String(), + }, + })) + + podSpec := &v1.PodSpec{ + Containers: []v1.Container{ + v1.Container{ + Resources: v1.ResourceRequirements{ + Requests: v1.ResourceList{ + gpuResourceName: resource.MustParse("1"), + addTolerationResourceName: resource.MustParse("1"), + noTolerationResourceName: resource.MustParse("1"), + }, + }, + }, + }, + Tolerations: []v1.Toleration{ + { + Key: "foo", + Operator: v1.TolerationOpExists, + Effect: v1.TaintEffectNoSchedule, + }, + }, + } + + podSpec = AddTolerationsForExtendedResources(podSpec) + fmt.Printf("%v\n", podSpec.Tolerations) + assert.Equal(t, 3, len(podSpec.Tolerations)) + assert.Equal(t, addTolerationResourceName.String(), podSpec.Tolerations[1].Key) + assert.Equal(t, v1.TolerationOpExists, podSpec.Tolerations[1].Operator) + assert.Equal(t, v1.TaintEffectNoSchedule, podSpec.Tolerations[1].Effect) + assert.Equal(t, gpuResourceName.String(), podSpec.Tolerations[2].Key) + assert.Equal(t, v1.TolerationOpExists, podSpec.Tolerations[2].Operator) + assert.Equal(t, v1.TaintEffectNoSchedule, podSpec.Tolerations[2].Effect) + + podSpec = &v1.PodSpec{ + InitContainers: []v1.Container{ + v1.Container{ + Resources: v1.ResourceRequirements{ + Requests: v1.ResourceList{ + gpuResourceName: resource.MustParse("1"), + addTolerationResourceName: resource.MustParse("1"), + noTolerationResourceName: resource.MustParse("1"), + }, + }, + }, + }, + Tolerations: []v1.Toleration{ + { + Key: "foo", + Operator: v1.TolerationOpExists, + Effect: v1.TaintEffectNoSchedule, + }, + }, + } + + podSpec = AddTolerationsForExtendedResources(podSpec) + assert.Equal(t, 3, len(podSpec.Tolerations)) + assert.Equal(t, addTolerationResourceName.String(), podSpec.Tolerations[1].Key) + assert.Equal(t, v1.TolerationOpExists, podSpec.Tolerations[1].Operator) + assert.Equal(t, v1.TaintEffectNoSchedule, podSpec.Tolerations[1].Effect) + assert.Equal(t, gpuResourceName.String(), podSpec.Tolerations[2].Key) + assert.Equal(t, v1.TolerationOpExists, podSpec.Tolerations[2].Operator) + assert.Equal(t, v1.TaintEffectNoSchedule, podSpec.Tolerations[2].Effect) + + podSpec = &v1.PodSpec{ + Containers: []v1.Container{ + v1.Container{ + Resources: v1.ResourceRequirements{ + Requests: v1.ResourceList{ + gpuResourceName: resource.MustParse("1"), + addTolerationResourceName: resource.MustParse("1"), + noTolerationResourceName: resource.MustParse("1"), + }, + }, + }, + }, + Tolerations: []v1.Toleration{ + { + Key: "foo", + Operator: v1.TolerationOpExists, + Effect: v1.TaintEffectNoSchedule, + }, + { + Key: gpuResourceName.String(), + Operator: v1.TolerationOpExists, + Effect: v1.TaintEffectNoSchedule, + }, + }, + } + + podSpec = AddTolerationsForExtendedResources(podSpec) + assert.Equal(t, 3, len(podSpec.Tolerations)) + assert.Equal(t, gpuResourceName.String(), podSpec.Tolerations[1].Key) + assert.Equal(t, v1.TolerationOpExists, podSpec.Tolerations[1].Operator) + assert.Equal(t, v1.TaintEffectNoSchedule, podSpec.Tolerations[1].Effect) + assert.Equal(t, addTolerationResourceName.String(), podSpec.Tolerations[2].Key) + assert.Equal(t, v1.TolerationOpExists, podSpec.Tolerations[2].Operator) + assert.Equal(t, v1.TaintEffectNoSchedule, podSpec.Tolerations[2].Effect) +} diff --git a/flyteplugins/go/tasks/plugins/k8s/dask/dask_test.go b/flyteplugins/go/tasks/plugins/k8s/dask/dask_test.go index eba53067ef..bc8b4adef4 100644 --- a/flyteplugins/go/tasks/plugins/k8s/dask/dask_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/dask/dask_test.go @@ -525,9 +525,10 @@ func TestBuildResouceDaskUsePodTemplate(t *testing.T) { func TestBuildResourceDaskExtendedResources(t *testing.T) { assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{ - GpuDeviceNodeLabel: "gpu-node-label", - GpuPartitionSizeNodeLabel: "gpu-partition-size", - GpuResourceName: flytek8s.ResourceNvidiaGPU, + GpuDeviceNodeLabel: "gpu-node-label", + GpuPartitionSizeNodeLabel: "gpu-partition-size", + GpuResourceName: flytek8s.ResourceNvidiaGPU, + AddTolerationsForExtendedResources: []string{"nvidia.com/gpu"}, })) fixtures := []struct { @@ -569,6 +570,11 @@ func TestBuildResourceDaskExtendedResources(t *testing.T) { Operator: v1.TolerationOpEqual, Effect: v1.TaintEffectNoSchedule, }, + { + Key: "nvidia.com/gpu", + Operator: v1.TolerationOpExists, + Effect: v1.TaintEffectNoSchedule, + }, }, }, { @@ -620,6 +626,11 @@ func TestBuildResourceDaskExtendedResources(t *testing.T) { Operator: v1.TolerationOpEqual, Effect: v1.TaintEffectNoSchedule, }, + { + Key: "nvidia.com/gpu", + Operator: v1.TolerationOpExists, + Effect: v1.TaintEffectNoSchedule, + }, }, }, } diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go index 2bc668f2c6..346b34adb6 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go @@ -368,9 +368,10 @@ func TestBuildResourceMPIForWrongInput(t *testing.T) { func TestBuildResourceMPIExtendedResources(t *testing.T) { assert.NoError(t, flytek8sConfig.SetK8sPluginConfig(&flytek8sConfig.K8sPluginConfig{ - GpuDeviceNodeLabel: "gpu-node-label", - GpuPartitionSizeNodeLabel: "gpu-partition-size", - GpuResourceName: flytek8s.ResourceNvidiaGPU, + GpuDeviceNodeLabel: "gpu-node-label", + GpuPartitionSizeNodeLabel: "gpu-partition-size", + GpuResourceName: flytek8s.ResourceNvidiaGPU, + AddTolerationsForExtendedResources: []string{"nvidia.com/gpu"}, })) fixtures := []struct { @@ -412,6 +413,11 @@ func TestBuildResourceMPIExtendedResources(t *testing.T) { Operator: corev1.TolerationOpEqual, Effect: corev1.TaintEffectNoSchedule, }, + { + Key: "nvidia.com/gpu", + Operator: corev1.TolerationOpExists, + Effect: corev1.TaintEffectNoSchedule, + }, }, }, { @@ -463,6 +469,11 @@ func TestBuildResourceMPIExtendedResources(t *testing.T) { Operator: corev1.TolerationOpEqual, Effect: corev1.TaintEffectNoSchedule, }, + { + Key: "nvidia.com/gpu", + Operator: corev1.TolerationOpExists, + Effect: corev1.TaintEffectNoSchedule, + }, }, }, } diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go index 741a8a00c8..0f38bb2851 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go @@ -473,9 +473,10 @@ func TestBuildResourcePytorchContainerImage(t *testing.T) { func TestBuildResourcePytorchExtendedResources(t *testing.T) { assert.NoError(t, flytek8sConfig.SetK8sPluginConfig(&flytek8sConfig.K8sPluginConfig{ - GpuDeviceNodeLabel: "gpu-node-label", - GpuPartitionSizeNodeLabel: "gpu-partition-size", - GpuResourceName: flytek8s.ResourceNvidiaGPU, + GpuDeviceNodeLabel: "gpu-node-label", + GpuPartitionSizeNodeLabel: "gpu-partition-size", + GpuResourceName: flytek8s.ResourceNvidiaGPU, + AddTolerationsForExtendedResources: []string{"nvidia.com/gpu"}, })) fixtures := []struct { @@ -517,6 +518,11 @@ func TestBuildResourcePytorchExtendedResources(t *testing.T) { Operator: corev1.TolerationOpEqual, Effect: corev1.TaintEffectNoSchedule, }, + { + Key: "nvidia.com/gpu", + Operator: corev1.TolerationOpExists, + Effect: corev1.TaintEffectNoSchedule, + }, }, }, { @@ -568,6 +574,11 @@ func TestBuildResourcePytorchExtendedResources(t *testing.T) { Operator: corev1.TolerationOpEqual, Effect: corev1.TaintEffectNoSchedule, }, + { + Key: "nvidia.com/gpu", + Operator: corev1.TolerationOpExists, + Effect: corev1.TaintEffectNoSchedule, + }, }, }, } diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go index c1e5dd264d..d264b69c9e 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go @@ -374,6 +374,8 @@ func buildHeadPodTemplate(primaryContainer *v1.Container, basePodSpec *v1.PodSpe return v1.PodTemplateSpec{}, err } + basePodSpec = flytek8s.AddTolerationsForExtendedResources(basePodSpec) + podTemplateSpec := v1.PodTemplateSpec{ Spec: *basePodSpec, ObjectMeta: *objectMeta, @@ -502,6 +504,8 @@ func buildWorkerPodTemplate(primaryContainer *v1.Container, basePodSpec *v1.PodS return v1.PodTemplateSpec{}, err } + basePodSpec = flytek8s.AddTolerationsForExtendedResources(basePodSpec) + podTemplateSpec := v1.PodTemplateSpec{ Spec: *basePodSpec, ObjectMeta: *objectMetadata, diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go index 708939485b..40aa509f54 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go @@ -280,9 +280,10 @@ func TestBuildResourceRayContainerImage(t *testing.T) { func TestBuildResourceRayExtendedResources(t *testing.T) { assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{ - GpuDeviceNodeLabel: "gpu-node-label", - GpuPartitionSizeNodeLabel: "gpu-partition-size", - GpuResourceName: flytek8s.ResourceNvidiaGPU, + GpuDeviceNodeLabel: "gpu-node-label", + GpuPartitionSizeNodeLabel: "gpu-partition-size", + GpuResourceName: flytek8s.ResourceNvidiaGPU, + AddTolerationsForExtendedResources: []string{"nvidia.com/gpu"}, })) params := []struct { @@ -324,6 +325,11 @@ func TestBuildResourceRayExtendedResources(t *testing.T) { Operator: corev1.TolerationOpEqual, Effect: corev1.TaintEffectNoSchedule, }, + { + Key: "nvidia.com/gpu", + Operator: corev1.TolerationOpExists, + Effect: corev1.TaintEffectNoSchedule, + }, }, }, { @@ -375,6 +381,11 @@ func TestBuildResourceRayExtendedResources(t *testing.T) { Operator: corev1.TolerationOpEqual, Effect: corev1.TaintEffectNoSchedule, }, + { + Key: "nvidia.com/gpu", + Operator: corev1.TolerationOpExists, + Effect: corev1.TaintEffectNoSchedule, + }, }, }, }