diff --git a/go.mod b/go.mod index e5863add0..6a2331949 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/athena v1.0.0 github.com/bstadlbauer/dask-k8s-operator-go-client v0.1.0 github.com/coocood/freecache v1.1.1 - github.com/flyteorg/flyteidl v1.3.16 + github.com/flyteorg/flyteidl v1.3.19 github.com/flyteorg/flytestdlib v1.0.15 github.com/go-test/deep v1.0.7 github.com/golang/protobuf v1.5.2 diff --git a/go.sum b/go.sum index 09a4ffa14..70bbe278f 100644 --- a/go.sum +++ b/go.sum @@ -232,8 +232,8 @@ github.com/evanphx/json-patch v4.12.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQL github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w= github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= -github.com/flyteorg/flyteidl v1.3.16 h1:mRq1VeUl5LP12dezbGHLQcrLuAmO9kawK9X7arqCInM= -github.com/flyteorg/flyteidl v1.3.16/go.mod h1:Pkt2skI1LiHs/2ZoekBnyPhuGOFMiuul6HHcKGZBsbM= +github.com/flyteorg/flyteidl v1.3.19 h1:i79Dh7UoP8Z4LEJ2ox6jlfZVJtFZ+r4g84CJj1gh22Y= +github.com/flyteorg/flyteidl v1.3.19/go.mod h1:Pkt2skI1LiHs/2ZoekBnyPhuGOFMiuul6HHcKGZBsbM= github.com/flyteorg/flytestdlib v1.0.15 h1:kv9jDQmytbE84caY+pkZN8trJU2ouSAmESzpTEhfTt0= github.com/flyteorg/flytestdlib v1.0.15/go.mod h1:ghw/cjY0sEWIIbyCtcJnL/Gt7ZS7gf9SUi0CCPhbz3s= github.com/flyteorg/stow v0.3.6 h1:jt50ciM14qhKBaIrB+ppXXY+SXB59FNREFgTJqCyqIk= diff --git a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go index 71550c4bc..338f6cd56 100644 --- a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go +++ b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go @@ -73,26 +73,59 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx return nil, fmt.Errorf("number of worker should be more then 0") } - jobSpec := kubeflowv1.PyTorchJobSpec{ - PyTorchReplicaSpecs: map[commonOp.ReplicaType]*commonOp.ReplicaSpec{ - kubeflowv1.PyTorchJobReplicaTypeMaster: { - Template: v1.PodTemplateSpec{ - ObjectMeta: *objectMeta, - Spec: *podSpec, + var jobSpec kubeflowv1.PyTorchJobSpec + + elasticConfig := pytorchTaskExtraArgs.GetElasticConfig() + + if elasticConfig != nil { + minReplicas := elasticConfig.GetMinReplicas() + maxReplicas := elasticConfig.GetMaxReplicas() + nProcPerNode := elasticConfig.GetNprocPerNode() + maxRestarts := elasticConfig.GetMaxRestarts() + rdzvBackend := kubeflowv1.RDZVBackend(elasticConfig.GetRdzvBackend()) + + jobSpec = kubeflowv1.PyTorchJobSpec{ + ElasticPolicy: &kubeflowv1.ElasticPolicy{ + MinReplicas: &minReplicas, + MaxReplicas: &maxReplicas, + RDZVBackend: &rdzvBackend, + NProcPerNode: &nProcPerNode, + MaxRestarts: &maxRestarts, + }, + PyTorchReplicaSpecs: map[commonOp.ReplicaType]*commonOp.ReplicaSpec{ + kubeflowv1.PyTorchJobReplicaTypeWorker: { + Replicas: &workers, + Template: v1.PodTemplateSpec{ + ObjectMeta: *objectMeta, + Spec: *podSpec, + }, + RestartPolicy: commonOp.RestartPolicyNever, }, - RestartPolicy: commonOp.RestartPolicyNever, }, - kubeflowv1.PyTorchJobReplicaTypeWorker: { - Replicas: &workers, - Template: v1.PodTemplateSpec{ - ObjectMeta: *objectMeta, - Spec: *podSpec, + } + + } else { + + jobSpec = kubeflowv1.PyTorchJobSpec{ + PyTorchReplicaSpecs: map[commonOp.ReplicaType]*commonOp.ReplicaSpec{ + kubeflowv1.PyTorchJobReplicaTypeMaster: { + Template: v1.PodTemplateSpec{ + ObjectMeta: *objectMeta, + Spec: *podSpec, + }, + RestartPolicy: commonOp.RestartPolicyNever, + }, + kubeflowv1.PyTorchJobReplicaTypeWorker: { + Replicas: &workers, + Template: v1.PodTemplateSpec{ + ObjectMeta: *objectMeta, + Spec: *podSpec, + }, + RestartPolicy: commonOp.RestartPolicyNever, }, - RestartPolicy: commonOp.RestartPolicyNever, }, - }, + } } - job := &kubeflowv1.PyTorchJob{ TypeMeta: metav1.TypeMeta{ Kind: kubeflowv1.PytorchJobKind, diff --git a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go index f1979dbf2..150bdb59a 100644 --- a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go +++ b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go @@ -69,6 +69,13 @@ func dummyPytorchCustomObj(workers int32) *plugins.DistributedPyTorchTrainingTas } } +func dummyElasticPytorchCustomObj(workers int32, elasticConfig plugins.ElasticConfig) *plugins.DistributedPyTorchTrainingTask { + return &plugins.DistributedPyTorchTrainingTask{ + Workers: workers, + ElasticConfig: &elasticConfig, + } +} + func dummyPytorchTaskTemplate(id string, pytorchCustomObj *plugins.DistributedPyTorchTrainingTask) *core.TaskTemplate { ptObjJSON, err := utils.MarshalToString(pytorchCustomObj) @@ -260,7 +267,7 @@ func dummyPytorchJobResource(pytorchResourceHandler pytorchOperatorResourceHandl } ptObj := dummyPytorchCustomObj(workers) - taskTemplate := dummyPytorchTaskTemplate("the job", ptObj) + taskTemplate := dummyPytorchTaskTemplate("job1", ptObj) resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate)) if err != nil { panic(err) @@ -282,11 +289,11 @@ func dummyPytorchJobResource(pytorchResourceHandler pytorchOperatorResourceHandl } } -func TestBuildResourcePytorch(t *testing.T) { +func TestBuildResourcePytorchElastic(t *testing.T) { pytorchResourceHandler := pytorchOperatorResourceHandler{} - ptObj := dummyPytorchCustomObj(100) - taskTemplate := dummyPytorchTaskTemplate("the job", ptObj) + ptObj := dummyElasticPytorchCustomObj(2, plugins.ElasticConfig{MinReplicas: 1, MaxReplicas: 2, NprocPerNode: 4, RdzvBackend: "c10d"}) + taskTemplate := dummyPytorchTaskTemplate("job2", ptObj) resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate)) assert.NoError(t, err) @@ -294,7 +301,41 @@ func TestBuildResourcePytorch(t *testing.T) { pytorchJob, ok := resource.(*kubeflowv1.PyTorchJob) assert.True(t, ok) + assert.Equal(t, int32(2), *pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Replicas) + assert.NotNil(t, pytorchJob.Spec.ElasticPolicy) + assert.Equal(t, int32(1), *pytorchJob.Spec.ElasticPolicy.MinReplicas) + assert.Equal(t, int32(2), *pytorchJob.Spec.ElasticPolicy.MaxReplicas) + assert.Equal(t, int32(4), *pytorchJob.Spec.ElasticPolicy.NProcPerNode) + assert.Equal(t, kubeflowv1.RDZVBackend("c10d"), *pytorchJob.Spec.ElasticPolicy.RDZVBackend) + + assert.Equal(t, 1, len(pytorchJob.Spec.PyTorchReplicaSpecs)) + assert.Contains(t, pytorchJob.Spec.PyTorchReplicaSpecs, kubeflowv1.PyTorchJobReplicaTypeWorker) + + var hasContainerWithDefaultPytorchName = false + + for _, container := range pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Template.Spec.Containers { + if container.Name == kubeflowv1.PytorchJobDefaultContainerName { + hasContainerWithDefaultPytorchName = true + } + } + + assert.True(t, hasContainerWithDefaultPytorchName) +} + +func TestBuildResourcePytorch(t *testing.T) { + pytorchResourceHandler := pytorchOperatorResourceHandler{} + + ptObj := dummyPytorchCustomObj(100) + taskTemplate := dummyPytorchTaskTemplate("job3", ptObj) + + res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate)) + assert.NoError(t, err) + assert.NotNil(t, res) + + pytorchJob, ok := res.(*kubeflowv1.PyTorchJob) + assert.True(t, ok) assert.Equal(t, int32(100), *pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Replicas) + assert.Nil(t, pytorchJob.Spec.ElasticPolicy) for _, replicaSpec := range pytorchJob.Spec.PyTorchReplicaSpecs { var hasContainerWithDefaultPytorchName = false @@ -392,17 +433,17 @@ func TestReplicaCounts(t *testing.T) { ptObj := dummyPytorchCustomObj(test.workerReplicaCount) taskTemplate := dummyPytorchTaskTemplate("the job", ptObj) - resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate)) + res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate)) if test.expectError { assert.Error(t, err) - assert.Nil(t, resource) + assert.Nil(t, res) return } assert.NoError(t, err) - assert.NotNil(t, resource) + assert.NotNil(t, res) - job, ok := resource.(*kubeflowv1.PyTorchJob) + job, ok := res.(*kubeflowv1.PyTorchJob) assert.True(t, ok) assert.Len(t, job.Spec.PyTorchReplicaSpecs, len(test.contains))