From e3888f354142dd536b05df95b3bb4f118f87e5a7 Mon Sep 17 00:00:00 2001
From: guozhen la
Date: Wed, 4 Oct 2023 14:17:30 -0400
Subject: [PATCH] Ray plugin: namespace, service account, and resource overrides

---
 flyteplugins/go/tasks/plugins/k8s/ray/ray.go  |  35 ++++-
 .../go/tasks/plugins/k8s/ray/ray_test.go      | 139 +++++++++++++++---
 2 files changed, 150 insertions(+), 24 deletions(-)

diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go
index 31a86b44dfd..592c20c3cf3 100644
--- a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go
+++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go
@@ -79,8 +79,15 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC
 	cfg := GetConfig()
 	headReplicas := int32(1)
 	headNodeRayStartParams := make(map[string]string)
-	if rayJob.RayCluster.HeadGroupSpec != nil && rayJob.RayCluster.HeadGroupSpec.RayStartParams != nil {
-		headNodeRayStartParams = rayJob.RayCluster.HeadGroupSpec.RayStartParams
+	headGroupResources := &v1.ResourceRequirements{}
+	if rayJob.RayCluster.HeadGroupSpec != nil {
+		if rayJob.RayCluster.HeadGroupSpec.RayStartParams != nil {
+			headNodeRayStartParams = rayJob.RayCluster.HeadGroupSpec.RayStartParams
+		}
+		headGroupResources, err = flytek8s.ToK8sResourceRequirements(rayJob.RayCluster.HeadGroupSpec.Resources)
+		if err != nil {
+			return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification on Resources[%v], Err: [%v]", headGroupResources, err.Error())
+		}
 	} else if headNode := cfg.Defaults.HeadNode; len(headNode.StartParameters) > 0 {
 		headNodeRayStartParams = headNode.StartParameters
 	}
@@ -101,6 +108,10 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC
 		headNodeRayStartParams[DisableUsageStatsStartParameter] = "true"
 	}
 
+	if rayJob.RayCluster.Namespace != "" {
+		objectMeta.Namespace = rayJob.RayCluster.Namespace
+	}
+
 	enableIngress := true
 	rayClusterSpec := rayv1alpha1.RayClusterSpec{
 		HeadGroupSpec: rayv1alpha1.HeadGroupSpec{
@@ -114,7 +125,12 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC
 	}
 
 	for _, spec := range rayJob.RayCluster.WorkerGroupSpec {
-		workerPodTemplate := buildWorkerPodTemplate(&container, podSpec, objectMeta, taskCtx)
+		workerGroupResources, err := flytek8s.ToK8sResourceRequirements(spec.Resources)
+		if err != nil {
+			return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification on Resources[%v], Err: [%v]", workerGroupResources, err.Error())
+		}
+
+		workerPodTemplate := buildWorkerPodTemplate(&container, podSpec, objectMeta, taskCtx, workerGroupResources)
 
 		minReplicas := spec.Replicas
 		maxReplicas := spec.Replicas
@@ -153,7 +169,10 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC
 		rayClusterSpec.WorkerGroupSpecs = append(rayClusterSpec.WorkerGroupSpecs, workerNodeSpec)
 	}
 
-	serviceAccountName := flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata())
+	serviceAccountName := rayJob.RayCluster.K8SServiceAccount
+	if serviceAccountName == "" {
+		serviceAccountName = flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata())
+	}
 
 	rayClusterSpec.HeadGroupSpec.Template.Spec.ServiceAccountName = serviceAccountName
 	for index := range rayClusterSpec.WorkerGroupSpecs {
@@ -180,12 +199,16 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC
 	return &rayJobObject, nil
 }
 
-func buildHeadPodTemplate(container *v1.Container, podSpec *v1.PodSpec, objectMeta *metav1.ObjectMeta, taskCtx pluginsCore.TaskExecutionContext) v1.PodTemplateSpec {
+func buildHeadPodTemplate(container *v1.Container, podSpec *v1.PodSpec, objectMeta *metav1.ObjectMeta, taskCtx pluginsCore.TaskExecutionContext, resources *v1.ResourceRequirements) v1.PodTemplateSpec {
 	// Some configs are copy from https://github.com/ray-project/kuberay/blob/b72e6bdcd9b8c77a9dc6b5da8560910f3a0c3ffd/apiserver/pkg/util/cluster.go#L97
 	// They should always be the same, so we could hard code here.
 	primaryContainer := container.DeepCopy()
 	primaryContainer.Name = "ray-head"
 
+	if len(resources.Requests) >= 1 || len(resources.Limits) >= 1 {
+		primaryContainer.Resources = *resources
+	}
+
 	envs := []v1.EnvVar{
 		{
 			Name: "MY_POD_IP",
@@ -232,7 +255,7 @@ func buildHeadPodTemplate(container *v1.Container, podSpec *v1.PodSpec, objectMe
 	return podTemplateSpec
 }
 
-func buildWorkerPodTemplate(container *v1.Container, podSpec *v1.PodSpec, objectMetadata *metav1.ObjectMeta, taskCtx pluginsCore.TaskExecutionContext) v1.PodTemplateSpec {
+func buildWorkerPodTemplate(container *v1.Container, podSpec *v1.PodSpec, objectMetadata *metav1.ObjectMeta, taskCtx pluginsCore.TaskExecutionContext, resources *v1.ResourceRequirements) v1.PodTemplateSpec {
 	// Some configs are copy from https://github.com/ray-project/kuberay/blob/b72e6bdcd9b8c77a9dc6b5da8560910f3a0c3ffd/apiserver/pkg/util/cluster.go#L185
 	// They should always be the same, so we could hard code here.
 	initContainers := []v1.Container{
diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go
index 5e7b82d55cb..75a1dc9d058 100644
--- a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go
+++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go
@@ -33,6 +33,8 @@ import (
 
 const testImage = "image://"
 const serviceAccount = "ray_sa"
+const serviceAccountOverride = "ray_sa_override"
+const namespaceOverride = "ray_namespace_override"
 
 var (
 	dummyEnvVars = []*core.KeyValuePair{
@@ -43,6 +45,52 @@ var (
 		"test-args",
 	}
 
+	headResourceOverride = core.Resources{
+		Requests: []*core.Resources_ResourceEntry{
+			{
+				Name:  core.Resources_CPU,
+				Value: "1000m",
+			},
+			{
+				Name:  core.Resources_MEMORY,
+				Value: "2Gi",
+			},
+		},
+		Limits: []*core.Resources_ResourceEntry{
+			{
+				Name:  core.Resources_CPU,
+				Value: "2000m",
+			},
+			{
+				Name:  core.Resources_MEMORY,
+				Value: "4Gi",
+			},
+		},
+	}
+
+	workerResourceOverride = core.Resources{
+		Requests: []*core.Resources_ResourceEntry{
+			{
+				Name:  core.Resources_CPU,
+				Value: "5",
+			},
+			{
+				Name:  core.Resources_MEMORY,
+				Value: "10G",
+			},
+		},
+		Limits: []*core.Resources_ResourceEntry{
+			{
+				Name:  core.Resources_CPU,
+				Value: "10",
+			},
+			{
+				Name:  core.Resources_MEMORY,
+				Value: "20G",
+			},
+		},
+	}
+
 	resourceRequirements = &corev1.ResourceRequirements{
 		Limits: corev1.ResourceList{
 			corev1.ResourceCPU:    resource.MustParse("1000m"),
@@ -68,6 +116,17 @@ func dummyRayCustomObj() *plugins.RayJob {
 	}
 }
 
+func dummyRayCustomObjWithOverrides() *plugins.RayJob {
+	return &plugins.RayJob{
+		RayCluster: &plugins.RayCluster{
+			K8SServiceAccount: serviceAccountOverride,
+			Namespace:         namespaceOverride,
+			HeadGroupSpec:     &plugins.HeadGroupSpec{RayStartParams: map[string]string{"num-cpus": "1"}, Resources: &headResourceOverride},
+			WorkerGroupSpec:   []*plugins.WorkerGroupSpec{{GroupName: workerGroupName, Replicas: 3, Resources: &workerResourceOverride}},
+		},
+	}
+}
+
 func dummyRayTaskTemplate(id string, rayJobObj *plugins.RayJob) *core.TaskTemplate {
 
 	ptObjJSON, err := utils.MarshalToString(rayJobObj)
@@ -172,26 +231,70 @@ func TestBuildResourceRay(t *testing.T) {
 	assert.True(t, ok)
 
 	headReplica := int32(1)
-	assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Replicas, &headReplica)
-	assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.ServiceAccountName, serviceAccount)
-	assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.RayStartParams,
-		map[string]string{
-			"dashboard-host": "0.0.0.0", "disable-usage-stats": "true", "include-dashboard": "true",
-			"node-ip-address": "$MY_POD_IP", "num-cpus": "1"})
-	assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Annotations, map[string]string{"annotation-1": "val1"})
-	assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Labels, map[string]string{"label-1": "val1"})
-	assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.Tolerations, toleration)
+	assert.Equal(t, &headReplica, ray.Spec.RayClusterSpec.HeadGroupSpec.Replicas)
+	assert.Equal(t, serviceAccount, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.ServiceAccountName)
+	assert.Equal(t, map[string]string{"dashboard-host": "0.0.0.0", "include-dashboard": "true", "node-ip-address": "$MY_POD_IP", "num-cpus": "1"},
+		ray.Spec.RayClusterSpec.HeadGroupSpec.RayStartParams)
+	assert.Equal(t, map[string]string{"annotation-1": "val1"}, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Annotations)
+	assert.Equal(t, map[string]string{"label-1": "val1"}, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Labels)
+	assert.Equal(t, toleration, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.Tolerations)
 
 	workerReplica := int32(3)
-	assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Replicas, &workerReplica)
-	assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].MinReplicas, &workerReplica)
-	assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].MaxReplicas, &workerReplica)
-	assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].GroupName, workerGroupName)
-	assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec.ServiceAccountName, serviceAccount)
-	assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].RayStartParams, map[string]string{"disable-usage-stats": "true", "node-ip-address": "$MY_POD_IP"})
-	assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Annotations, map[string]string{"annotation-1": "val1"})
-	assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Labels, map[string]string{"label-1": "val1"})
-	assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec.Tolerations, toleration)
+	assert.Equal(t, &workerReplica, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Replicas)
+	assert.Equal(t, &workerReplica, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].MinReplicas)
+	assert.Equal(t, &workerReplica, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].MaxReplicas)
+	assert.Equal(t, workerGroupName, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].GroupName)
+	assert.Equal(t, serviceAccount, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec.ServiceAccountName)
+	assert.Equal(t, map[string]string{"node-ip-address": "$MY_POD_IP"}, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].RayStartParams)
+	assert.Equal(t, map[string]string{"annotation-1": "val1"}, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Annotations)
+	assert.Equal(t, map[string]string{"label-1": "val1"}, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Labels)
+	assert.Equal(t, toleration, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec.Tolerations)
+}
+
+func TestBuildResourceRayWithOverrides(t *testing.T) {
+	rayJobResourceHandler := rayJobResourceHandler{}
+	taskTemplate := dummyRayTaskTemplate("ray-id", dummyRayCustomObjWithOverrides())
+	expectedHeadResources, _ := flytek8s.ToK8sResourceRequirements(&headResourceOverride)
+	expectedWorkerResources, _ := flytek8s.ToK8sResourceRequirements(&workerResourceOverride)
+	toleration := []corev1.Toleration{{
+		Key:      "storage",
+		Value:    "dedicated",
+		Operator: corev1.TolerationOpExists,
+		Effect:   corev1.TaintEffectNoSchedule,
+	}}
+	err := config.SetK8sPluginConfig(&config.K8sPluginConfig{DefaultTolerations: toleration})
+	assert.Nil(t, err)
+
+	RayResource, err := rayJobResourceHandler.BuildResource(context.TODO(), dummyRayTaskContext(taskTemplate))
+	assert.Nil(t, err)
+
+	assert.NotNil(t, RayResource)
+	ray, ok := RayResource.(*rayv1alpha1.RayJob)
+	assert.True(t, ok)
+
+	headReplica := int32(1)
+	assert.Equal(t, namespaceOverride, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.ObjectMeta.Namespace)
+	assert.Equal(t, &headReplica, ray.Spec.RayClusterSpec.HeadGroupSpec.Replicas)
+	assert.Equal(t, serviceAccountOverride, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.ServiceAccountName)
+	assert.Equal(t, map[string]string{"dashboard-host": "0.0.0.0", "include-dashboard": "true", "node-ip-address": "$MY_POD_IP", "num-cpus": "1"},
+		ray.Spec.RayClusterSpec.HeadGroupSpec.RayStartParams)
+	assert.Equal(t, map[string]string{"annotation-1": "val1"}, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Annotations)
+	assert.Equal(t, map[string]string{"label-1": "val1"}, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Labels)
+	assert.Equal(t, toleration, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.Tolerations)
+	assert.Equal(t, *expectedHeadResources, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.Containers[0].Resources)
+
+	workerReplica := int32(3)
+	assert.Equal(t, namespaceOverride, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.ObjectMeta.Namespace)
+	assert.Equal(t, &workerReplica, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Replicas)
+	assert.Equal(t, &workerReplica, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].MinReplicas)
+	assert.Equal(t, &workerReplica, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].MaxReplicas)
+	assert.Equal(t, workerGroupName, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].GroupName)
+	assert.Equal(t, serviceAccountOverride, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec.ServiceAccountName)
+	assert.Equal(t, map[string]string{"node-ip-address": "$MY_POD_IP"}, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].RayStartParams)
+	assert.Equal(t, map[string]string{"annotation-1": "val1"}, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Annotations)
+	assert.Equal(t, map[string]string{"label-1": "val1"}, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Labels)
+	assert.Equal(t, toleration, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec.Tolerations)
+	assert.Equal(t, *expectedWorkerResources, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec.Containers[0].Resources)
 }
 
 func TestDefaultStartParameters(t *testing.T) {
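
Reviewer note (not part of the patch itself): the sketch below is a minimal, illustrative example of how a RayJob custom proto could populate the override fields that BuildResource now reads (RayCluster.Namespace, RayCluster.K8SServiceAccount, and the per-group Resources). It mirrors dummyRayCustomObjWithOverrides from the test above; the flyteidl import paths are an assumption and should match whatever the ray plugin package already imports for the plugins and core protos.

// rayjob_overrides_sketch.go - illustrative only, not part of the patch.
package main

import (
	"fmt"

	core "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"         // assumed import path
	plugins "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins"   // assumed import path
)

func main() {
	rayJob := &plugins.RayJob{
		RayCluster: &plugins.RayCluster{
			// Overrides picked up by BuildResource in this patch; when left
			// empty, the plugin falls back to the task execution metadata
			// service account and the default namespace/resources.
			K8SServiceAccount: "ray-sa-override",
			Namespace:         "ray-namespace-override",
			HeadGroupSpec: &plugins.HeadGroupSpec{
				RayStartParams: map[string]string{"num-cpus": "1"},
				Resources: &core.Resources{
					Requests: []*core.Resources_ResourceEntry{
						{Name: core.Resources_CPU, Value: "1000m"},
						{Name: core.Resources_MEMORY, Value: "2Gi"},
					},
				},
			},
			WorkerGroupSpec: []*plugins.WorkerGroupSpec{
				{
					GroupName: "worker-group",
					Replicas:  3,
					Resources: &core.Resources{
						Requests: []*core.Resources_ResourceEntry{
							{Name: core.Resources_CPU, Value: "5"},
							{Name: core.Resources_MEMORY, Value: "10G"},
						},
					},
				},
			},
		},
	}

	// The plugin converts these proto resources with flytek8s.ToK8sResourceRequirements
	// and applies them to the head and worker containers.
	fmt.Println(rayJob.RayCluster.Namespace, rayJob.RayCluster.K8SServiceAccount)
}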