diff --git a/go/tasks/pluginmachinery/flytek8s/pod_helper.go b/go/tasks/pluginmachinery/flytek8s/pod_helper.go index c0020b642..82f482285 100755 --- a/go/tasks/pluginmachinery/flytek8s/pod_helper.go +++ b/go/tasks/pluginmachinery/flytek8s/pod_helper.go @@ -26,8 +26,8 @@ const SIGKILL = 137 const defaultContainerTemplateName = "default" const primaryContainerTemplateName = "primary" -// ApplyInterruptibleNodeAffinity configures the node-affinity for the pod using the configuration specified. -func ApplyInterruptibleNodeAffinity(interruptible bool, podSpec *v1.PodSpec) { +// ApplyInterruptibleNodeSelectorRequirement configures the node selector requirement of the node-affinity using the configuration specified. +func ApplyInterruptibleNodeSelectorRequirement(interruptible bool, affinity *v1.Affinity) { // Determine node selector terms to add to node affinity var nodeSelectorRequirement v1.NodeSelectorRequirement if interruptible { @@ -42,24 +42,31 @@ func ApplyInterruptibleNodeAffinity(interruptible bool, podSpec *v1.PodSpec) { nodeSelectorRequirement = *config.GetK8sPluginConfig().NonInterruptibleNodeSelectorRequirement } - if podSpec.Affinity == nil { - podSpec.Affinity = &v1.Affinity{} + if affinity.NodeAffinity == nil { + affinity.NodeAffinity = &v1.NodeAffinity{} } - if podSpec.Affinity.NodeAffinity == nil { - podSpec.Affinity.NodeAffinity = &v1.NodeAffinity{} + if affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution == nil { + affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution = &v1.NodeSelector{} } - if podSpec.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution == nil { - podSpec.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution = &v1.NodeSelector{} - } - if len(podSpec.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms) > 0 { - nodeSelectorTerms := podSpec.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms + if len(affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms) > 0 { + nodeSelectorTerms := affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms for i := range nodeSelectorTerms { nst := &nodeSelectorTerms[i] nst.MatchExpressions = append(nst.MatchExpressions, nodeSelectorRequirement) } } else { - podSpec.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms = []v1.NodeSelectorTerm{v1.NodeSelectorTerm{MatchExpressions: []v1.NodeSelectorRequirement{nodeSelectorRequirement}}} + affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms = []v1.NodeSelectorTerm{v1.NodeSelectorTerm{MatchExpressions: []v1.NodeSelectorRequirement{nodeSelectorRequirement}}} + } + +} + +// ApplyInterruptibleNodeAffinity configures the node-affinity for the pod using the configuration specified. +func ApplyInterruptibleNodeAffinity(interruptible bool, podSpec *v1.PodSpec) { + if podSpec.Affinity == nil { + podSpec.Affinity = &v1.Affinity{} } + + ApplyInterruptibleNodeSelectorRequirement(interruptible, podSpec.Affinity) } // UpdatePod updates the base pod spec used to execute tasks. This is configured with plugins and task metadata-specific options diff --git a/go/tasks/plugins/k8s/spark/spark.go b/go/tasks/plugins/k8s/spark/spark.go index 59dcb9aaf..c6ce78c1e 100755 --- a/go/tasks/plugins/k8s/spark/spark.go +++ b/go/tasks/plugins/k8s/spark/spark.go @@ -90,6 +90,7 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo for _, envVar := range envVars { sparkEnvVars[envVar.Name] = envVar.Value } + sparkEnvVars["FLYTE_MAX_ATTEMPTS"] = strconv.Itoa(int(taskCtx.TaskExecutionMetadata().GetMaxAttempts())) serviceAccountName := flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()) @@ -99,24 +100,34 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo } driverSpec := sparkOp.DriverSpec{ SparkPodSpec: sparkOp.SparkPodSpec{ + Affinity: config.GetK8sPluginConfig().DefaultAffinity, Annotations: annotations, Labels: labels, EnvVars: sparkEnvVars, Image: &container.Image, SecurityContenxt: config.GetK8sPluginConfig().DefaultPodSecurityContext.DeepCopy(), DNSConfig: config.GetK8sPluginConfig().DefaultPodDNSConfig.DeepCopy(), + Tolerations: config.GetK8sPluginConfig().DefaultTolerations, + SchedulerName: &config.GetK8sPluginConfig().SchedulerName, + NodeSelector: config.GetK8sPluginConfig().DefaultNodeSelector, + HostNetwork: config.GetK8sPluginConfig().EnableHostNetworkingPod, }, ServiceAccount: &serviceAccountName, } executorSpec := sparkOp.ExecutorSpec{ SparkPodSpec: sparkOp.SparkPodSpec{ + Affinity: config.GetK8sPluginConfig().DefaultAffinity.DeepCopy(), Annotations: annotations, Labels: labels, Image: &container.Image, EnvVars: sparkEnvVars, SecurityContenxt: config.GetK8sPluginConfig().DefaultPodSecurityContext.DeepCopy(), DNSConfig: config.GetK8sPluginConfig().DefaultPodDNSConfig.DeepCopy(), + Tolerations: config.GetK8sPluginConfig().DefaultTolerations, + SchedulerName: &config.GetK8sPluginConfig().SchedulerName, + NodeSelector: config.GetK8sPluginConfig().DefaultNodeSelector, + HostNetwork: config.GetK8sPluginConfig().EnableHostNetworkingPod, }, } @@ -225,11 +236,16 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo j.Spec.MainClass = &sparkJob.MainClass } - // Add Tolerations/NodeSelector to only Executor pods. + // Add Interruptible Tolerations/NodeSelector to only Executor pods. + // The Interruptible NodeSelector takes precedence over the DefaultNodeSelector if taskCtx.TaskExecutionMetadata().IsInterruptible() { - j.Spec.Executor.Tolerations = config.GetK8sPluginConfig().InterruptibleTolerations + j.Spec.Executor.Tolerations = append(j.Spec.Executor.Tolerations, config.GetK8sPluginConfig().InterruptibleTolerations...) j.Spec.Executor.NodeSelector = config.GetK8sPluginConfig().InterruptibleNodeSelector } + + // Add interruptible/non-interruptible node selector requirements to executor pod + flytek8s.ApplyInterruptibleNodeSelectorRequirement(taskCtx.TaskExecutionMetadata().IsInterruptible(), j.Spec.Executor.Affinity) + return j, nil } diff --git a/go/tasks/plugins/k8s/spark/spark_test.go b/go/tasks/plugins/k8s/spark/spark_test.go index 13721aee6..5ac764910 100755 --- a/go/tasks/plugins/k8s/spark/spark_test.go +++ b/go/tasks/plugins/k8s/spark/spark_test.go @@ -3,6 +3,7 @@ package spark import ( "context" "fmt" + "os" "strconv" "testing" @@ -353,7 +354,67 @@ func TestBuildResourceSpark(t *testing.T) { dnsOptVal1 := "1" dnsOptVal2 := "1" dnsOptVal3 := "3" + + // Set scheduler + schedulerName := "custom-scheduler" + + // Node selectors + defaultNodeSelector := map[string]string{ + "x/default": "true", + } + interruptibleNodeSelector := map[string]string{ + "x/interruptible": "true", + } + + defaultPodHostNetwork := true + + // Default env vars passed explicitly and default env vars derived from environment + defaultEnvVars := make(map[string]string) + defaultEnvVars["foo"] = "bar" + + defaultEnvVarsFromEnv := make(map[string]string) + targetKeyFromEnv := "TEST_VAR_FROM_ENV_KEY" + targetValueFromEnv := "TEST_VAR_FROM_ENV_VALUE" + os.Setenv(targetKeyFromEnv, targetValueFromEnv) + defer os.Unsetenv(targetKeyFromEnv) + defaultEnvVarsFromEnv["fooEnv"] = targetKeyFromEnv + + // Default affinity/anti-affinity + defaultAffinity := &corev1.Affinity{ + NodeAffinity: &corev1.NodeAffinity{ + RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{ + NodeSelectorTerms: []corev1.NodeSelectorTerm{ + { + MatchExpressions: []corev1.NodeSelectorRequirement{ + { + Key: "x/default", + Operator: corev1.NodeSelectorOpIn, + Values: []string{"true"}, + }, + }, + }, + }, + }, + }, + } + + // interruptible/non-interruptible nodeselector requirement + interruptibleNodeSelectorRequirement := &corev1.NodeSelectorRequirement{ + Key: "x/interruptible", + Operator: corev1.NodeSelectorOpIn, + Values: []string{"true"}, + } + + nonInterruptibleNodeSelectorRequirement := &corev1.NodeSelectorRequirement{ + Key: "x/non-interruptible", + Operator: corev1.NodeSelectorOpIn, + Values: []string{"true"}, + } + + // NonInterruptibleNodeSelectorRequirement + assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{ + DefaultAffinity: defaultAffinity, DefaultPodSecurityContext: &corev1.PodSecurityContext{ RunAsUser: &runAsUser, }, @@ -378,9 +439,16 @@ func TestBuildResourceSpark(t *testing.T) { }, Searches: []string{"ns1.svc.cluster-domain.example", "my.dns.search.suffix"}, }, - InterruptibleNodeSelector: map[string]string{ - "x/interruptible": "true", + DefaultTolerations: []corev1.Toleration{ + { + Key: "x/flyte", + Value: "default", + Operator: "Equal", + Effect: "NoSchedule", + }, }, + DefaultNodeSelector: defaultNodeSelector, + InterruptibleNodeSelector: interruptibleNodeSelector, InterruptibleTolerations: []corev1.Toleration{ { Key: "x/flyte", @@ -388,7 +456,14 @@ func TestBuildResourceSpark(t *testing.T) { Operator: "Equal", Effect: "NoSchedule", }, - }}), + }, + InterruptibleNodeSelectorRequirement: interruptibleNodeSelectorRequirement, + NonInterruptibleNodeSelectorRequirement: nonInterruptibleNodeSelectorRequirement, + SchedulerName: schedulerName, + EnableHostNetworkingPod: &defaultPodHostNetwork, + DefaultEnvVars: defaultEnvVars, + DefaultEnvVarsFromEnv: defaultEnvVarsFromEnv, + }), ) resource, err := sparkResourceHandler.BuildResource(context.TODO(), dummySparkTaskContext(taskTemplate, true)) assert.Nil(t, err) @@ -438,19 +513,40 @@ func TestBuildResourceSpark(t *testing.T) { assert.Equal(t, dummySparkConf["spark.driver.memory"], *sparkApp.Spec.Driver.Memory) assert.Equal(t, dummySparkConf["spark.executor.memory"], *sparkApp.Spec.Executor.Memory) assert.Equal(t, dummySparkConf["spark.batchScheduler"], *sparkApp.Spec.BatchScheduler) - - // Validate Interruptible Toleration and NodeSelector set for Executor but not Driver. - assert.Equal(t, 0, len(sparkApp.Spec.Driver.Tolerations)) - assert.Equal(t, 0, len(sparkApp.Spec.Driver.NodeSelector)) - - assert.Equal(t, 1, len(sparkApp.Spec.Executor.Tolerations)) + assert.Equal(t, schedulerName, *sparkApp.Spec.Executor.SchedulerName) + assert.Equal(t, schedulerName, *sparkApp.Spec.Driver.SchedulerName) + assert.Equal(t, defaultPodHostNetwork, *sparkApp.Spec.Executor.HostNetwork) + assert.Equal(t, defaultPodHostNetwork, *sparkApp.Spec.Driver.HostNetwork) + + // Validate + // * Interruptible Toleration and NodeSelector set for Executor but not Driver. + // * Validate Default NodeSelector set for Driver but overwritten with Interruptible NodeSelector for Executor. + // * Default Tolerations set for both Driver and Executor. + // * Interruptible/Non-Interruptible NodeSelectorRequirements set for Executor Affinity but not Driver Affinity. + assert.Equal(t, 1, len(sparkApp.Spec.Driver.Tolerations)) + assert.Equal(t, 1, len(sparkApp.Spec.Driver.NodeSelector)) + assert.Equal(t, defaultNodeSelector, sparkApp.Spec.Driver.NodeSelector) + tolDriverDefault := sparkApp.Spec.Driver.Tolerations[0] + assert.Equal(t, tolDriverDefault.Key, "x/flyte") + assert.Equal(t, tolDriverDefault.Value, "default") + assert.Equal(t, tolDriverDefault.Operator, corev1.TolerationOperator("Equal")) + assert.Equal(t, tolDriverDefault.Effect, corev1.TaintEffect("NoSchedule")) + + assert.Equal(t, 2, len(sparkApp.Spec.Executor.Tolerations)) assert.Equal(t, 1, len(sparkApp.Spec.Executor.NodeSelector)) - - tol := sparkApp.Spec.Executor.Tolerations[0] - assert.Equal(t, tol.Key, "x/flyte") - assert.Equal(t, tol.Value, "interruptible") - assert.Equal(t, tol.Operator, corev1.TolerationOperator("Equal")) - assert.Equal(t, tol.Effect, corev1.TaintEffect("NoSchedule")) + assert.Equal(t, interruptibleNodeSelector, sparkApp.Spec.Executor.NodeSelector) + + tolExecDefault := sparkApp.Spec.Executor.Tolerations[0] + assert.Equal(t, tolExecDefault.Key, "x/flyte") + assert.Equal(t, tolExecDefault.Value, "default") + assert.Equal(t, tolExecDefault.Operator, corev1.TolerationOperator("Equal")) + assert.Equal(t, tolExecDefault.Effect, corev1.TaintEffect("NoSchedule")) + + tolExecInterrupt := sparkApp.Spec.Executor.Tolerations[1] + assert.Equal(t, tolExecInterrupt.Key, "x/flyte") + assert.Equal(t, tolExecInterrupt.Value, "interruptible") + assert.Equal(t, tolExecInterrupt.Operator, corev1.TolerationOperator("Equal")) + assert.Equal(t, tolExecInterrupt.Effect, corev1.TaintEffect("NoSchedule")) assert.Equal(t, "true", sparkApp.Spec.Executor.NodeSelector["x/interruptible"]) for confKey, confVal := range dummySparkConf { @@ -485,6 +581,22 @@ func TestBuildResourceSpark(t *testing.T) { assert.Equal(t, dummySparkConf["spark.flyteorg.feature3.enabled"], sparkApp.Spec.SparkConf["spark.flyteorg.feature3.enabled"]) assert.Equal(t, len(sparkApp.Spec.Driver.EnvVars["FLYTE_MAX_ATTEMPTS"]), 1) + assert.Equal(t, sparkApp.Spec.Driver.EnvVars["foo"], defaultEnvVars["foo"]) + assert.Equal(t, sparkApp.Spec.Executor.EnvVars["foo"], defaultEnvVars["foo"]) + assert.Equal(t, sparkApp.Spec.Driver.EnvVars["fooEnv"], targetValueFromEnv) + assert.Equal(t, sparkApp.Spec.Executor.EnvVars["fooEnv"], targetValueFromEnv) + assert.Equal(t, sparkApp.Spec.Driver.Affinity, defaultAffinity) + + assert.Equal( + t, + sparkApp.Spec.Executor.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0], + defaultAffinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0], + ) + assert.Equal( + t, + sparkApp.Spec.Executor.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[1], + *interruptibleNodeSelectorRequirement, + ) // Case 2: Driver/Executor request cores set. dummyConfWithRequest := make(map[string]string) @@ -514,10 +626,30 @@ func TestBuildResourceSpark(t *testing.T) { assert.True(t, ok) // Validate Interruptible Toleration and NodeSelector not set for both Driver and Executors. - assert.Equal(t, 0, len(sparkApp.Spec.Driver.Tolerations)) - assert.Equal(t, 0, len(sparkApp.Spec.Driver.NodeSelector)) - assert.Equal(t, 0, len(sparkApp.Spec.Executor.Tolerations)) - assert.Equal(t, 0, len(sparkApp.Spec.Executor.NodeSelector)) + // Validate that the default Toleration and NodeSelector are set for both Driver and Executors. + assert.Equal(t, 1, len(sparkApp.Spec.Driver.Tolerations)) + assert.Equal(t, 1, len(sparkApp.Spec.Driver.NodeSelector)) + assert.Equal(t, defaultNodeSelector, sparkApp.Spec.Driver.NodeSelector) + assert.Equal(t, 1, len(sparkApp.Spec.Executor.Tolerations)) + assert.Equal(t, 1, len(sparkApp.Spec.Executor.NodeSelector)) + assert.Equal(t, defaultNodeSelector, sparkApp.Spec.Executor.NodeSelector) + assert.Equal(t, sparkApp.Spec.Executor.Tolerations[0].Key, "x/flyte") + assert.Equal(t, sparkApp.Spec.Executor.Tolerations[0].Value, "default") + assert.Equal(t, sparkApp.Spec.Driver.Tolerations[0].Key, "x/flyte") + assert.Equal(t, sparkApp.Spec.Driver.Tolerations[0].Value, "default") + + // Validate correct affinity and nodeselector requirements are set for both Driver and Executors. + assert.Equal(t, sparkApp.Spec.Driver.Affinity, defaultAffinity) + assert.Equal( + t, + sparkApp.Spec.Executor.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0], + defaultAffinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0], + ) + assert.Equal( + t, + sparkApp.Spec.Executor.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[1], + *nonInterruptibleNodeSelectorRequirement, + ) // Case 4: Invalid Spark Task-Template taskTemplate.Custom = nil