This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Pass K8sPluginConfig to spark driver and executor pods #patch (#271)
* Pass default tolerations to spark driver and executor

Signed-off-by: fg91 <[email protected]>

* Test passing default tolerations to spark driver and executor

Signed-off-by: fg91 <[email protected]>

* Pass scheduler name to driver and executor SparkPodSpec

Signed-off-by: fg91 <[email protected]>

* Carry DefaultNodeSelector from k8s plugin config to SparkPodSpec

Signed-off-by: fg91 <[email protected]>

* Carry over EnableHostNetworkingPod

Signed-off-by: Fabio Grätz <[email protected]>

* Test carrying over of default env vars

Signed-off-by: Fabio Grätz <[email protected]>

* Carry over DefaultEnvVarsFromEnv

Signed-off-by: Fabio Grätz <[email protected]>

* Carry over DefaultAffinity

Signed-off-by: Fabio Grätz <[email protected]>

* Doc behaviour of default and interruptible NodeSelector and Tolerations

Signed-off-by: Fabio Grätz <[email protected]>

* Don't carry over default env vars from env and fix test

Signed-off-by: Fabio Grätz <[email protected]>

* Lint

Signed-off-by: Fabio Grätz <[email protected]>

* Apply node selector requirement to pod affinity

Signed-off-by: Fabio Grätz <[email protected]>

Signed-off-by: fg91 <[email protected]>
Signed-off-by: Fabio Grätz <[email protected]>
Co-authored-by: Fabio Grätz <[email protected]>
fg91 and Fabio Grätz authored Nov 1, 2022
1 parent 932e97b commit a87ca40
Showing 3 changed files with 188 additions and 33 deletions.
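For orientation, the sketch below is not part of the commit: it shows the kind of K8sPluginConfig whose default and interruptible scheduling settings are, after this change, carried over to the Spark driver and executor pod specs. The field names are taken from the diff and tests below; the import paths, the values, and the surrounding main function are illustrative assumptions.

package main

import (
	corev1 "k8s.io/api/core/v1"

	// Import path assumed from the flyteplugins repository layout.
	"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s/config"
)

func main() {
	hostNetwork := true

	// These defaults now flow into both the driver and the executor SparkPodSpec.
	_ = config.SetK8sPluginConfig(&config.K8sPluginConfig{
		SchedulerName:           "custom-scheduler",
		EnableHostNetworkingPod: &hostNetwork,
		DefaultNodeSelector:     map[string]string{"x/default": "true"},
		DefaultTolerations: []corev1.Toleration{
			{Key: "x/flyte", Value: "default", Operator: "Equal", Effect: "NoSchedule"},
		},
		// Interruptible settings are still applied to executor pods only.
		InterruptibleNodeSelector: map[string]string{"x/interruptible": "true"},
		InterruptibleTolerations: []corev1.Toleration{
			{Key: "x/flyte", Value: "interruptible", Operator: "Equal", Effect: "NoSchedule"},
		},
	})
}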
31 changes: 19 additions & 12 deletions go/tasks/pluginmachinery/flytek8s/pod_helper.go
@@ -26,8 +26,8 @@ const SIGKILL = 137
 const defaultContainerTemplateName = "default"
 const primaryContainerTemplateName = "primary"
 
-// ApplyInterruptibleNodeAffinity configures the node-affinity for the pod using the configuration specified.
-func ApplyInterruptibleNodeAffinity(interruptible bool, podSpec *v1.PodSpec) {
+// ApplyInterruptibleNodeSelectorRequirement configures the node selector requirement of the node-affinity using the configuration specified.
+func ApplyInterruptibleNodeSelectorRequirement(interruptible bool, affinity *v1.Affinity) {
 	// Determine node selector terms to add to node affinity
 	var nodeSelectorRequirement v1.NodeSelectorRequirement
 	if interruptible {
@@ -42,24 +42,31 @@ func ApplyInterruptibleNodeAffinity(interruptible bool, podSpec *v1.PodSpec) {
 		nodeSelectorRequirement = *config.GetK8sPluginConfig().NonInterruptibleNodeSelectorRequirement
 	}
 
-	if podSpec.Affinity == nil {
-		podSpec.Affinity = &v1.Affinity{}
+	if affinity.NodeAffinity == nil {
+		affinity.NodeAffinity = &v1.NodeAffinity{}
 	}
-	if podSpec.Affinity.NodeAffinity == nil {
-		podSpec.Affinity.NodeAffinity = &v1.NodeAffinity{}
+	if affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution == nil {
+		affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution = &v1.NodeSelector{}
 	}
-	if podSpec.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution == nil {
-		podSpec.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution = &v1.NodeSelector{}
-	}
-	if len(podSpec.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms) > 0 {
-		nodeSelectorTerms := podSpec.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms
+	if len(affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms) > 0 {
+		nodeSelectorTerms := affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms
 		for i := range nodeSelectorTerms {
 			nst := &nodeSelectorTerms[i]
 			nst.MatchExpressions = append(nst.MatchExpressions, nodeSelectorRequirement)
 		}
 	} else {
-		podSpec.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms = []v1.NodeSelectorTerm{v1.NodeSelectorTerm{MatchExpressions: []v1.NodeSelectorRequirement{nodeSelectorRequirement}}}
+		affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms = []v1.NodeSelectorTerm{v1.NodeSelectorTerm{MatchExpressions: []v1.NodeSelectorRequirement{nodeSelectorRequirement}}}
 	}
+
 }
 
+// ApplyInterruptibleNodeAffinity configures the node-affinity for the pod using the configuration specified.
+func ApplyInterruptibleNodeAffinity(interruptible bool, podSpec *v1.PodSpec) {
+	if podSpec.Affinity == nil {
+		podSpec.Affinity = &v1.Affinity{}
+	}
+
+	ApplyInterruptibleNodeSelectorRequirement(interruptible, podSpec.Affinity)
+}
+
 // UpdatePod updates the base pod spec used to execute tasks. This is configured with plugins and task metadata-specific options
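A minimal usage sketch of the refactored helper (not part of the commit): ApplyInterruptibleNodeSelectorRequirement now works on a *v1.Affinity directly — the Spark plugin passes the executor's affinity — and appends the configured requirement to every existing node selector term. The import paths are assumed from the repository layout; the affinity literal mirrors the fixtures in the test file below.

package main

import (
	"fmt"

	v1 "k8s.io/api/core/v1"

	// Import paths assumed from the flyteplugins repository layout.
	"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s"
	"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s/config"
)

func main() {
	// Register the interruptible node selector requirement that the helper reads.
	_ = config.SetK8sPluginConfig(&config.K8sPluginConfig{
		InterruptibleNodeSelectorRequirement: &v1.NodeSelectorRequirement{
			Key:      "x/interruptible",
			Operator: v1.NodeSelectorOpIn,
			Values:   []string{"true"},
		},
	})

	// An affinity that already pins pods to default nodes, e.g. coming from DefaultAffinity.
	affinity := &v1.Affinity{
		NodeAffinity: &v1.NodeAffinity{
			RequiredDuringSchedulingIgnoredDuringExecution: &v1.NodeSelector{
				NodeSelectorTerms: []v1.NodeSelectorTerm{
					{MatchExpressions: []v1.NodeSelectorRequirement{
						{Key: "x/default", Operator: v1.NodeSelectorOpIn, Values: []string{"true"}},
					}},
				},
			},
		},
	}

	// The helper appends the interruptible requirement to every existing node selector term.
	flytek8s.ApplyInterruptibleNodeSelectorRequirement(true, affinity)

	fmt.Println(affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions)
}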
20 changes: 18 additions & 2 deletions go/tasks/plugins/k8s/spark/spark.go
@@ -90,6 +90,7 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo
 	for _, envVar := range envVars {
 		sparkEnvVars[envVar.Name] = envVar.Value
 	}
+
 	sparkEnvVars["FLYTE_MAX_ATTEMPTS"] = strconv.Itoa(int(taskCtx.TaskExecutionMetadata().GetMaxAttempts()))
 
 	serviceAccountName := flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata())
@@ -99,24 +100,34 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo
 	}
 	driverSpec := sparkOp.DriverSpec{
 		SparkPodSpec: sparkOp.SparkPodSpec{
+			Affinity: config.GetK8sPluginConfig().DefaultAffinity,
 			Annotations: annotations,
 			Labels: labels,
 			EnvVars: sparkEnvVars,
 			Image: &container.Image,
 			SecurityContenxt: config.GetK8sPluginConfig().DefaultPodSecurityContext.DeepCopy(),
 			DNSConfig: config.GetK8sPluginConfig().DefaultPodDNSConfig.DeepCopy(),
+			Tolerations: config.GetK8sPluginConfig().DefaultTolerations,
+			SchedulerName: &config.GetK8sPluginConfig().SchedulerName,
+			NodeSelector: config.GetK8sPluginConfig().DefaultNodeSelector,
+			HostNetwork: config.GetK8sPluginConfig().EnableHostNetworkingPod,
 		},
 		ServiceAccount: &serviceAccountName,
 	}
 
 	executorSpec := sparkOp.ExecutorSpec{
 		SparkPodSpec: sparkOp.SparkPodSpec{
+			Affinity: config.GetK8sPluginConfig().DefaultAffinity.DeepCopy(),
 			Annotations: annotations,
 			Labels: labels,
 			Image: &container.Image,
 			EnvVars: sparkEnvVars,
 			SecurityContenxt: config.GetK8sPluginConfig().DefaultPodSecurityContext.DeepCopy(),
 			DNSConfig: config.GetK8sPluginConfig().DefaultPodDNSConfig.DeepCopy(),
+			Tolerations: config.GetK8sPluginConfig().DefaultTolerations,
+			SchedulerName: &config.GetK8sPluginConfig().SchedulerName,
+			NodeSelector: config.GetK8sPluginConfig().DefaultNodeSelector,
+			HostNetwork: config.GetK8sPluginConfig().EnableHostNetworkingPod,
 		},
 	}
 
@@ -225,11 +236,16 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo
 		j.Spec.MainClass = &sparkJob.MainClass
 	}
 
-	// Add Tolerations/NodeSelector to only Executor pods.
+	// Add Interruptible Tolerations/NodeSelector to only Executor pods.
+	// The Interruptible NodeSelector takes precedence over the DefaultNodeSelector
 	if taskCtx.TaskExecutionMetadata().IsInterruptible() {
-		j.Spec.Executor.Tolerations = config.GetK8sPluginConfig().InterruptibleTolerations
+		j.Spec.Executor.Tolerations = append(j.Spec.Executor.Tolerations, config.GetK8sPluginConfig().InterruptibleTolerations...)
 		j.Spec.Executor.NodeSelector = config.GetK8sPluginConfig().InterruptibleNodeSelector
 	}
+
+	// Add interruptible/non-interruptible node selector requirements to executor pod
+	flytek8s.ApplyInterruptibleNodeSelectorRequirement(taskCtx.TaskExecutionMetadata().IsInterruptible(), j.Spec.Executor.Affinity)
+
 	return j, nil
 }
 
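Spelled out as a hypothetical helper (not code from this commit; sparkOp refers to the Spark operator API types and the import paths are assumptions), the precedence rules applied above are: interruptible tolerations are appended, so the defaults are kept, while the interruptible node selector replaces the default one.

package sketch

import (
	// Import paths are assumptions based on the dependencies used in this repository.
	sparkOp "github.com/GoogleCloudPlatform/spark-on-k8s-operator/pkg/apis/sparkoperator.k8s.io/v1beta2"

	"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s/config"
)

// applyInterruptibleOverrides is a hypothetical summary of the executor-only
// behaviour for interruptible tasks shown in the diff above.
func applyInterruptibleOverrides(exec *sparkOp.ExecutorSpec, cfg *config.K8sPluginConfig) {
	// Appended: the executor keeps the default tolerations and gains the interruptible ones.
	exec.Tolerations = append(exec.Tolerations, cfg.InterruptibleTolerations...)
	// Replaced: the interruptible node selector takes precedence over the default node selector.
	exec.NodeSelector = cfg.InterruptibleNodeSelector
}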
170 changes: 151 additions & 19 deletions go/tasks/plugins/k8s/spark/spark_test.go
@@ -3,6 +3,7 @@ package spark
 import (
 	"context"
 	"fmt"
+	"os"
 	"strconv"
 	"testing"
 
@@ -353,7 +354,67 @@ func TestBuildResourceSpark(t *testing.T) {
 	dnsOptVal1 := "1"
 	dnsOptVal2 := "1"
 	dnsOptVal3 := "3"
+
+	// Set scheduler
+	schedulerName := "custom-scheduler"
+
+	// Node selectors
+	defaultNodeSelector := map[string]string{
+		"x/default": "true",
+	}
+	interruptibleNodeSelector := map[string]string{
+		"x/interruptible": "true",
+	}
+
+	defaultPodHostNetwork := true
+
+	// Default env vars passed explicitly and default env vars derived from environment
+	defaultEnvVars := make(map[string]string)
+	defaultEnvVars["foo"] = "bar"
+
+	defaultEnvVarsFromEnv := make(map[string]string)
+	targetKeyFromEnv := "TEST_VAR_FROM_ENV_KEY"
+	targetValueFromEnv := "TEST_VAR_FROM_ENV_VALUE"
+	os.Setenv(targetKeyFromEnv, targetValueFromEnv)
+	defer os.Unsetenv(targetKeyFromEnv)
+	defaultEnvVarsFromEnv["fooEnv"] = targetKeyFromEnv
+
+	// Default affinity/anti-affinity
+	defaultAffinity := &corev1.Affinity{
+		NodeAffinity: &corev1.NodeAffinity{
+			RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{
+				NodeSelectorTerms: []corev1.NodeSelectorTerm{
+					{
+						MatchExpressions: []corev1.NodeSelectorRequirement{
+							{
+								Key: "x/default",
+								Operator: corev1.NodeSelectorOpIn,
+								Values: []string{"true"},
+							},
+						},
+					},
+				},
+			},
+		},
+	}
+
+	// interruptible/non-interruptible nodeselector requirement
+	interruptibleNodeSelectorRequirement := &corev1.NodeSelectorRequirement{
+		Key: "x/interruptible",
+		Operator: corev1.NodeSelectorOpIn,
+		Values: []string{"true"},
+	}
+
+	nonInterruptibleNodeSelectorRequirement := &corev1.NodeSelectorRequirement{
+		Key: "x/non-interruptible",
+		Operator: corev1.NodeSelectorOpIn,
+		Values: []string{"true"},
+	}
+
+	// NonInterruptibleNodeSelectorRequirement
+
 	assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{
+		DefaultAffinity: defaultAffinity,
 		DefaultPodSecurityContext: &corev1.PodSecurityContext{
 			RunAsUser: &runAsUser,
 		},
@@ -378,17 +439,31 @@
 			},
 			Searches: []string{"ns1.svc.cluster-domain.example", "my.dns.search.suffix"},
 		},
-		InterruptibleNodeSelector: map[string]string{
-			"x/interruptible": "true",
+		DefaultTolerations: []corev1.Toleration{
+			{
+				Key: "x/flyte",
+				Value: "default",
+				Operator: "Equal",
+				Effect: "NoSchedule",
+			},
 		},
+		DefaultNodeSelector: defaultNodeSelector,
+		InterruptibleNodeSelector: interruptibleNodeSelector,
 		InterruptibleTolerations: []corev1.Toleration{
 			{
 				Key: "x/flyte",
 				Value: "interruptible",
 				Operator: "Equal",
 				Effect: "NoSchedule",
 			},
-		}}),
+		},
+		InterruptibleNodeSelectorRequirement: interruptibleNodeSelectorRequirement,
+		NonInterruptibleNodeSelectorRequirement: nonInterruptibleNodeSelectorRequirement,
+		SchedulerName: schedulerName,
+		EnableHostNetworkingPod: &defaultPodHostNetwork,
+		DefaultEnvVars: defaultEnvVars,
+		DefaultEnvVarsFromEnv: defaultEnvVarsFromEnv,
+	}),
 	)
 	resource, err := sparkResourceHandler.BuildResource(context.TODO(), dummySparkTaskContext(taskTemplate, true))
 	assert.Nil(t, err)
@@ -438,19 +513,40 @@ func TestBuildResourceSpark(t *testing.T) {
 	assert.Equal(t, dummySparkConf["spark.driver.memory"], *sparkApp.Spec.Driver.Memory)
 	assert.Equal(t, dummySparkConf["spark.executor.memory"], *sparkApp.Spec.Executor.Memory)
 	assert.Equal(t, dummySparkConf["spark.batchScheduler"], *sparkApp.Spec.BatchScheduler)
 
-	// Validate Interruptible Toleration and NodeSelector set for Executor but not Driver.
-	assert.Equal(t, 0, len(sparkApp.Spec.Driver.Tolerations))
-	assert.Equal(t, 0, len(sparkApp.Spec.Driver.NodeSelector))
-
-	assert.Equal(t, 1, len(sparkApp.Spec.Executor.Tolerations))
+	assert.Equal(t, schedulerName, *sparkApp.Spec.Executor.SchedulerName)
+	assert.Equal(t, schedulerName, *sparkApp.Spec.Driver.SchedulerName)
+	assert.Equal(t, defaultPodHostNetwork, *sparkApp.Spec.Executor.HostNetwork)
+	assert.Equal(t, defaultPodHostNetwork, *sparkApp.Spec.Driver.HostNetwork)
+
+	// Validate
+	// * Interruptible Toleration and NodeSelector set for Executor but not Driver.
+	// * Validate Default NodeSelector set for Driver but overwritten with Interruptible NodeSelector for Executor.
+	// * Default Tolerations set for both Driver and Executor.
+	// * Interruptible/Non-Interruptible NodeSelectorRequirements set for Executor Affinity but not Driver Affinity.
+	assert.Equal(t, 1, len(sparkApp.Spec.Driver.Tolerations))
+	assert.Equal(t, 1, len(sparkApp.Spec.Driver.NodeSelector))
+	assert.Equal(t, defaultNodeSelector, sparkApp.Spec.Driver.NodeSelector)
+	tolDriverDefault := sparkApp.Spec.Driver.Tolerations[0]
+	assert.Equal(t, tolDriverDefault.Key, "x/flyte")
+	assert.Equal(t, tolDriverDefault.Value, "default")
+	assert.Equal(t, tolDriverDefault.Operator, corev1.TolerationOperator("Equal"))
+	assert.Equal(t, tolDriverDefault.Effect, corev1.TaintEffect("NoSchedule"))
+
+	assert.Equal(t, 2, len(sparkApp.Spec.Executor.Tolerations))
 	assert.Equal(t, 1, len(sparkApp.Spec.Executor.NodeSelector))
-
-	tol := sparkApp.Spec.Executor.Tolerations[0]
-	assert.Equal(t, tol.Key, "x/flyte")
-	assert.Equal(t, tol.Value, "interruptible")
-	assert.Equal(t, tol.Operator, corev1.TolerationOperator("Equal"))
-	assert.Equal(t, tol.Effect, corev1.TaintEffect("NoSchedule"))
+	assert.Equal(t, interruptibleNodeSelector, sparkApp.Spec.Executor.NodeSelector)
+
+	tolExecDefault := sparkApp.Spec.Executor.Tolerations[0]
+	assert.Equal(t, tolExecDefault.Key, "x/flyte")
+	assert.Equal(t, tolExecDefault.Value, "default")
+	assert.Equal(t, tolExecDefault.Operator, corev1.TolerationOperator("Equal"))
+	assert.Equal(t, tolExecDefault.Effect, corev1.TaintEffect("NoSchedule"))
+
+	tolExecInterrupt := sparkApp.Spec.Executor.Tolerations[1]
+	assert.Equal(t, tolExecInterrupt.Key, "x/flyte")
+	assert.Equal(t, tolExecInterrupt.Value, "interruptible")
+	assert.Equal(t, tolExecInterrupt.Operator, corev1.TolerationOperator("Equal"))
+	assert.Equal(t, tolExecInterrupt.Effect, corev1.TaintEffect("NoSchedule"))
 	assert.Equal(t, "true", sparkApp.Spec.Executor.NodeSelector["x/interruptible"])
-
 	for confKey, confVal := range dummySparkConf {
@@ -485,6 +581,22 @@ func TestBuildResourceSpark(t *testing.T) {
 	assert.Equal(t, dummySparkConf["spark.flyteorg.feature3.enabled"], sparkApp.Spec.SparkConf["spark.flyteorg.feature3.enabled"])
 
 	assert.Equal(t, len(sparkApp.Spec.Driver.EnvVars["FLYTE_MAX_ATTEMPTS"]), 1)
+	assert.Equal(t, sparkApp.Spec.Driver.EnvVars["foo"], defaultEnvVars["foo"])
+	assert.Equal(t, sparkApp.Spec.Executor.EnvVars["foo"], defaultEnvVars["foo"])
+	assert.Equal(t, sparkApp.Spec.Driver.EnvVars["fooEnv"], targetValueFromEnv)
+	assert.Equal(t, sparkApp.Spec.Executor.EnvVars["fooEnv"], targetValueFromEnv)
+	assert.Equal(t, sparkApp.Spec.Driver.Affinity, defaultAffinity)
+
+	assert.Equal(
+		t,
+		sparkApp.Spec.Executor.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0],
+		defaultAffinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0],
+	)
+	assert.Equal(
+		t,
+		sparkApp.Spec.Executor.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[1],
+		*interruptibleNodeSelectorRequirement,
+	)
 
 	// Case 2: Driver/Executor request cores set.
 	dummyConfWithRequest := make(map[string]string)
@@ -514,10 +626,30 @@ func TestBuildResourceSpark(t *testing.T) {
 	assert.True(t, ok)
 
 	// Validate Interruptible Toleration and NodeSelector not set for both Driver and Executors.
-	assert.Equal(t, 0, len(sparkApp.Spec.Driver.Tolerations))
-	assert.Equal(t, 0, len(sparkApp.Spec.Driver.NodeSelector))
-	assert.Equal(t, 0, len(sparkApp.Spec.Executor.Tolerations))
-	assert.Equal(t, 0, len(sparkApp.Spec.Executor.NodeSelector))
+	// Validate that the default Toleration and NodeSelector are set for both Driver and Executors.
+	assert.Equal(t, 1, len(sparkApp.Spec.Driver.Tolerations))
+	assert.Equal(t, 1, len(sparkApp.Spec.Driver.NodeSelector))
+	assert.Equal(t, defaultNodeSelector, sparkApp.Spec.Driver.NodeSelector)
+	assert.Equal(t, 1, len(sparkApp.Spec.Executor.Tolerations))
+	assert.Equal(t, 1, len(sparkApp.Spec.Executor.NodeSelector))
+	assert.Equal(t, defaultNodeSelector, sparkApp.Spec.Executor.NodeSelector)
+	assert.Equal(t, sparkApp.Spec.Executor.Tolerations[0].Key, "x/flyte")
+	assert.Equal(t, sparkApp.Spec.Executor.Tolerations[0].Value, "default")
+	assert.Equal(t, sparkApp.Spec.Driver.Tolerations[0].Key, "x/flyte")
+	assert.Equal(t, sparkApp.Spec.Driver.Tolerations[0].Value, "default")
+
+	// Validate correct affinity and nodeselector requirements are set for both Driver and Executors.
+	assert.Equal(t, sparkApp.Spec.Driver.Affinity, defaultAffinity)
+	assert.Equal(
+		t,
+		sparkApp.Spec.Executor.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0],
+		defaultAffinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0],
+	)
+	assert.Equal(
+		t,
+		sparkApp.Spec.Executor.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[1],
+		*nonInterruptibleNodeSelectorRequirement,
+	)
 
 	// Case 4: Invalid Spark Task-Template
 	taskTemplate.Custom = nil
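To summarize what the interruptible-case assertions above check: the executor's node affinity ends up with the default requirement (from DefaultAffinity) and the interruptible requirement in the same node selector term, while the driver keeps the default affinity unchanged. A sketch of that expected value, using only k8s.io/api/core/v1 types and the fixture values from the test:

package sketch

import corev1 "k8s.io/api/core/v1"

// expectedExecutorAffinity mirrors the executor affinity asserted in the test
// above for an interruptible task (values taken from the test fixtures).
var expectedExecutorAffinity = &corev1.Affinity{
	NodeAffinity: &corev1.NodeAffinity{
		RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{
			NodeSelectorTerms: []corev1.NodeSelectorTerm{{
				MatchExpressions: []corev1.NodeSelectorRequirement{
					// From DefaultAffinity, carried over to the executor.
					{Key: "x/default", Operator: corev1.NodeSelectorOpIn, Values: []string{"true"}},
					// Appended by ApplyInterruptibleNodeSelectorRequirement for interruptible tasks.
					{Key: "x/interruptible", Operator: corev1.NodeSelectorOpIn, Values: []string{"true"}},
				},
			}},
		},
	},
}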
