Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
guozhen la committed Oct 4, 2023
1 parent a9f5d24 commit e3888f3
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 24 deletions.
35 changes: 29 additions & 6 deletions flyteplugins/go/tasks/plugins/k8s/ray/ray.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,15 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC
cfg := GetConfig()
headReplicas := int32(1)
headNodeRayStartParams := make(map[string]string)
if rayJob.RayCluster.HeadGroupSpec != nil && rayJob.RayCluster.HeadGroupSpec.RayStartParams != nil {
headNodeRayStartParams = rayJob.RayCluster.HeadGroupSpec.RayStartParams
headGroupResources := &v1.ResourceRequirements{}
if rayJob.RayCluster.HeadGroupSpec != nil{
if rayJob.RayCluster.HeadGroupSpec.RayStartParams != nil {
headNodeRayStartParams = rayJob.RayCluster.HeadGroupSpec.RayStartParams
}
headGroupResources, err = flytek8s.ToK8sResourceRequirements(rayJob.RayCluster.HeadGroupSpec.Resources)
if err != nil {
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification on Resources[%v], Err: [%v]", headGroupResources, err.Error())
}
} else if headNode := cfg.Defaults.HeadNode; len(headNode.StartParameters) > 0 {
headNodeRayStartParams = headNode.StartParameters
}
Expand All @@ -101,6 +108,10 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC
headNodeRayStartParams[DisableUsageStatsStartParameter] = "true"
}

if rayJob.RayCluster.Namespace != "" {
objectMeta.Namespace = rayJob.RayCluster.Namespace
}

enableIngress := true
rayClusterSpec := rayv1alpha1.RayClusterSpec{
HeadGroupSpec: rayv1alpha1.HeadGroupSpec{
Expand All @@ -114,7 +125,12 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC
}

for _, spec := range rayJob.RayCluster.WorkerGroupSpec {
workerPodTemplate := buildWorkerPodTemplate(&container, podSpec, objectMeta, taskCtx)
workerGroupResources, err := flytek8s.ToK8sResourceRequirements(spec.Resources)
if err != nil {
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification on Resources[%v], Err: [%v]", workerGroupResources, err.Error())
}

workerPodTemplate := buildWorkerPodTemplate(&container, podSpec, objectMeta, taskCtx, workerGroupResources)

minReplicas := spec.Replicas
maxReplicas := spec.Replicas
Expand Down Expand Up @@ -153,7 +169,10 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC
rayClusterSpec.WorkerGroupSpecs = append(rayClusterSpec.WorkerGroupSpecs, workerNodeSpec)
}

serviceAccountName := flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata())
serviceAccountName := rayJob.RayCluster.K8SServiceAccount
if serviceAccountName == "" {
serviceAccountName = flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata())
}

rayClusterSpec.HeadGroupSpec.Template.Spec.ServiceAccountName = serviceAccountName
for index := range rayClusterSpec.WorkerGroupSpecs {
Expand All @@ -180,12 +199,16 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC
return &rayJobObject, nil
}

func buildHeadPodTemplate(container *v1.Container, podSpec *v1.PodSpec, objectMeta *metav1.ObjectMeta, taskCtx pluginsCore.TaskExecutionContext) v1.PodTemplateSpec {
func buildHeadPodTemplate(container *v1.Container, podSpec *v1.PodSpec, objectMeta *metav1.ObjectMeta, taskCtx pluginsCore.TaskExecutionContext, resources *v1.ResourceRequirements) v1.PodTemplateSpec {
// Some configs are copied from https://github.com/ray-project/kuberay/blob/b72e6bdcd9b8c77a9dc6b5da8560910f3a0c3ffd/apiserver/pkg/util/cluster.go#L97
// They should always be the same, so we can hard-code them here.
primaryContainer := container.DeepCopy()
primaryContainer.Name = "ray-head"

if len(resources.Requests) >= 1 || len(resources.Limits) >= 1 {
primaryContainer.Resources = *resources
}

envs := []v1.EnvVar{
{
Name: "MY_POD_IP",
Expand Down Expand Up @@ -232,7 +255,7 @@ func buildHeadPodTemplate(container *v1.Container, podSpec *v1.PodSpec, objectMe
return podTemplateSpec
}

func buildWorkerPodTemplate(container *v1.Container, podSpec *v1.PodSpec, objectMetadata *metav1.ObjectMeta, taskCtx pluginsCore.TaskExecutionContext) v1.PodTemplateSpec {
func buildWorkerPodTemplate(container *v1.Container, podSpec *v1.PodSpec, objectMetadata *metav1.ObjectMeta, taskCtx pluginsCore.TaskExecutionContext, resources *v1.ResourceRequirements) v1.PodTemplateSpec {
// Some configs are copied from https://github.com/ray-project/kuberay/blob/b72e6bdcd9b8c77a9dc6b5da8560910f3a0c3ffd/apiserver/pkg/util/cluster.go#L185
// They should always be the same, so we can hard-code them here.
initContainers := []v1.Container{
Expand Down
139 changes: 121 additions & 18 deletions flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ import (

const testImage = "image://"
const serviceAccount = "ray_sa"
const serviceAccountOverride = "ray_sa_override"
const namespaceOverride = "ray_namespace_override"

var (
dummyEnvVars = []*core.KeyValuePair{
Expand All @@ -43,6 +45,52 @@ var (
"test-args",
}

// headResourceOverride is the resource fixture attached to the head group in
// dummyRayCustomObjWithOverrides: it requests 1000m CPU / 2Gi memory and
// limits the head container to 2000m CPU / 4Gi memory.
headResourceOverride = core.Resources{
Requests: []*core.Resources_ResourceEntry{
{
Name: core.Resources_CPU,
Value: "1000m",
},
{
Name: core.Resources_MEMORY,
Value: "2Gi",
},
},
Limits: []*core.Resources_ResourceEntry{
{
Name: core.Resources_CPU,
Value: "2000m",
},
{
Name: core.Resources_MEMORY,
Value: "4Gi",
},
},
}

// workerResourceOverride is the resource fixture attached to the worker group
// in dummyRayCustomObjWithOverrides: it requests 5 CPU / 10G memory and limits
// the worker container to 10 CPU / 20G memory. The values intentionally differ
// from headResourceOverride so the two cannot be confused in assertions.
workerResourceOverride = core.Resources{
Requests: []*core.Resources_ResourceEntry{
{
Name: core.Resources_CPU,
Value: "5",
},
{
Name: core.Resources_MEMORY,
Value: "10G",
},
},
Limits: []*core.Resources_ResourceEntry{
{
Name: core.Resources_CPU,
Value: "10",
},
{
Name: core.Resources_MEMORY,
Value: "20G",
},
},
}

resourceRequirements = &corev1.ResourceRequirements{
Limits: corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("1000m"),
Expand All @@ -68,6 +116,17 @@ func dummyRayCustomObj() *plugins.RayJob {
}
}

// dummyRayCustomObjWithOverrides returns a RayJob fixture whose cluster sets
// every override exercised by the tests: a service account, a namespace, and
// explicit resources for the head group and the single worker group.
func dummyRayCustomObjWithOverrides() *plugins.RayJob {
	headGroup := &plugins.HeadGroupSpec{
		RayStartParams: map[string]string{"num-cpus": "1"},
		Resources:      &headResourceOverride,
	}

	workerGroup := &plugins.WorkerGroupSpec{
		GroupName: workerGroupName,
		Replicas:  3,
		Resources: &workerResourceOverride,
	}

	cluster := &plugins.RayCluster{
		K8SServiceAccount: serviceAccountOverride,
		Namespace:         namespaceOverride,
		HeadGroupSpec:     headGroup,
		WorkerGroupSpec:   []*plugins.WorkerGroupSpec{workerGroup},
	}

	return &plugins.RayJob{RayCluster: cluster}
}

func dummyRayTaskTemplate(id string, rayJobObj *plugins.RayJob) *core.TaskTemplate {

ptObjJSON, err := utils.MarshalToString(rayJobObj)
Expand Down Expand Up @@ -172,26 +231,70 @@ func TestBuildResourceRay(t *testing.T) {
assert.True(t, ok)

headReplica := int32(1)
assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Replicas, &headReplica)
assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.ServiceAccountName, serviceAccount)
assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.RayStartParams,
map[string]string{
"dashboard-host": "0.0.0.0", "disable-usage-stats": "true", "include-dashboard": "true",
"node-ip-address": "$MY_POD_IP", "num-cpus": "1"})
assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Annotations, map[string]string{"annotation-1": "val1"})
assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Labels, map[string]string{"label-1": "val1"})
assert.Equal(t, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.Tolerations, toleration)
assert.Equal(t, &headReplica, ray.Spec.RayClusterSpec.HeadGroupSpec.Replicas)
assert.Equal(t, serviceAccount, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.ServiceAccountName)
assert.Equal(t, map[string]string{"dashboard-host": "0.0.0.0", "include-dashboard": "true", "node-ip-address": "$MY_POD_IP", "num-cpus": "1"},
ray.Spec.RayClusterSpec.HeadGroupSpec.RayStartParams)
assert.Equal(t, map[string]string{"annotation-1": "val1"}, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Annotations)
assert.Equal(t, map[string]string{"label-1": "val1"}, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Labels)
assert.Equal(t, toleration, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.Tolerations)

workerReplica := int32(3)
assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Replicas, &workerReplica)
assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].MinReplicas, &workerReplica)
assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].MaxReplicas, &workerReplica)
assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].GroupName, workerGroupName)
assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec.ServiceAccountName, serviceAccount)
assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].RayStartParams, map[string]string{"disable-usage-stats": "true", "node-ip-address": "$MY_POD_IP"})
assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Annotations, map[string]string{"annotation-1": "val1"})
assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Labels, map[string]string{"label-1": "val1"})
assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec.Tolerations, toleration)
assert.Equal(t, &workerReplica, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Replicas)
assert.Equal(t, &workerReplica, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].MinReplicas)
assert.Equal(t, &workerReplica, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].MaxReplicas)
assert.Equal(t, workerGroupName, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].GroupName)
assert.Equal(t, serviceAccount, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec.ServiceAccountName)
assert.Equal(t, map[string]string{"node-ip-address": "$MY_POD_IP"}, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].RayStartParams)
assert.Equal(t, map[string]string{"annotation-1": "val1"}, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Annotations)
assert.Equal(t, map[string]string{"label-1": "val1"}, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Labels)
assert.Equal(t, toleration, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec.Tolerations)
}

// TestBuildResourceRayWithOverrides verifies that the overrides carried by the
// RayJob spec (service account, namespace, and per-group resources) end up on
// the head and worker pod templates of the RayJob produced by BuildResource.
//
// Preconditions that later assertions dereference (fixture conversion, the
// BuildResource result, the worker group and container slices) abort the test
// immediately so a broken setup is reported as a failure rather than a panic.
func TestBuildResourceRayWithOverrides(t *testing.T) {
	handler := rayJobResourceHandler{}
	taskTemplate := dummyRayTaskTemplate("ray-id", dummyRayCustomObjWithOverrides())

	// The expected values are dereferenced below, so a conversion error must
	// stop the test instead of being discarded.
	expectedHeadResources, err := flytek8s.ToK8sResourceRequirements(&headResourceOverride)
	if err != nil {
		t.Fatalf("converting head resource fixture: %v", err)
	}
	expectedWorkerResources, err := flytek8s.ToK8sResourceRequirements(&workerResourceOverride)
	if err != nil {
		t.Fatalf("converting worker resource fixture: %v", err)
	}

	toleration := []corev1.Toleration{{
		Key:      "storage",
		Value:    "dedicated",
		Operator: corev1.TolerationOpExists,
		Effect:   corev1.TaintEffectNoSchedule,
	}}
	err = config.SetK8sPluginConfig(&config.K8sPluginConfig{DefaultTolerations: toleration})
	assert.Nil(t, err)

	rayResource, err := handler.BuildResource(context.TODO(), dummyRayTaskContext(taskTemplate))
	if err != nil {
		t.Fatalf("building ray resource: %v", err)
	}

	ray, ok := rayResource.(*rayv1alpha1.RayJob)
	if !ok || ray == nil {
		t.Fatalf("expected a non-nil *rayv1alpha1.RayJob, got %T", rayResource)
	}

	// Head group assertions.
	headSpec := ray.Spec.RayClusterSpec.HeadGroupSpec
	headReplica := int32(1)
	assert.Equal(t, namespaceOverride, headSpec.Template.ObjectMeta.Namespace)
	assert.Equal(t, &headReplica, headSpec.Replicas)
	assert.Equal(t, serviceAccountOverride, headSpec.Template.Spec.ServiceAccountName)
	assert.Equal(t, map[string]string{"dashboard-host": "0.0.0.0", "include-dashboard": "true", "node-ip-address": "$MY_POD_IP", "num-cpus": "1"},
		headSpec.RayStartParams)
	assert.Equal(t, map[string]string{"annotation-1": "val1"}, headSpec.Template.Annotations)
	assert.Equal(t, map[string]string{"label-1": "val1"}, headSpec.Template.Labels)
	assert.Equal(t, toleration, headSpec.Template.Spec.Tolerations)
	if assert.NotEmpty(t, headSpec.Template.Spec.Containers) {
		assert.Equal(t, *expectedHeadResources, headSpec.Template.Spec.Containers[0].Resources)
	}

	// Worker group assertions; the fixture declares exactly one worker group.
	if len(ray.Spec.RayClusterSpec.WorkerGroupSpecs) == 0 {
		t.Fatalf("expected at least one worker group spec, got none")
	}
	workerSpec := ray.Spec.RayClusterSpec.WorkerGroupSpecs[0]
	workerReplica := int32(3)
	assert.Equal(t, namespaceOverride, workerSpec.Template.ObjectMeta.Namespace)
	assert.Equal(t, &workerReplica, workerSpec.Replicas)
	assert.Equal(t, &workerReplica, workerSpec.MinReplicas)
	assert.Equal(t, &workerReplica, workerSpec.MaxReplicas)
	assert.Equal(t, workerGroupName, workerSpec.GroupName)
	assert.Equal(t, serviceAccountOverride, workerSpec.Template.Spec.ServiceAccountName)
	assert.Equal(t, map[string]string{"node-ip-address": "$MY_POD_IP"}, workerSpec.RayStartParams)
	assert.Equal(t, map[string]string{"annotation-1": "val1"}, workerSpec.Template.Annotations)
	assert.Equal(t, map[string]string{"label-1": "val1"}, workerSpec.Template.Labels)
	assert.Equal(t, toleration, workerSpec.Template.Spec.Tolerations)
	if assert.NotEmpty(t, workerSpec.Template.Spec.Containers) {
		assert.Equal(t, *expectedWorkerResources, workerSpec.Template.Spec.Containers[0].Resources)
	}
}

func TestDefaultStartParameters(t *testing.T) {
Expand Down

0 comments on commit e3888f3

Please sign in to comment.