From 70c23c2cf8dccd44011c2168e7f532735ce8b99c Mon Sep 17 00:00:00 2001 From: Jeev B Date: Fri, 10 Nov 2023 11:42:52 -0800 Subject: [PATCH] Add support for displaying the Ray dashboard when a RayJob is active (#4397) * Refactor task logs framework Signed-off-by: Jeev B * Return templateLogPluginCollection instead of nil even if no plugins are specified Signed-off-by: Jeev B * Add support for displaying the Ray dashboard when a RayJob is active Signed-off-by: Jeev B * Fix tasklogs returned for Ray task Signed-off-by: Jeev B * Get tasklogs working with task phase Signed-off-by: Jeev B * Misc fixes Signed-off-by: Jeev B * Add tests for dashboard URL link Signed-off-by: Jeev B * Fix linting issues and merge conflicts Signed-off-by: Jeev B --------- Signed-off-by: Jeev B --- .../go/tasks/plugins/k8s/ray/config.go | 12 +-- flyteplugins/go/tasks/plugins/k8s/ray/ray.go | 40 ++++++--- .../go/tasks/plugins/k8s/ray/ray_test.go | 84 ++++++++++++++++--- 3 files changed, 108 insertions(+), 28 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/config.go b/flyteplugins/go/tasks/plugins/k8s/ray/config.go index e123c5b8ab..8601264edf 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/config.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/config.go @@ -8,6 +8,7 @@ import ( pluginsConfig "github.com/flyteorg/flyte/flyteplugins/go/tasks/config" "github.com/flyteorg/flyte/flyteplugins/go/tasks/logs" pluginmachinery "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog" "github.com/flyteorg/flyte/flytestdlib/config" ) @@ -78,11 +79,12 @@ type Config struct { DeprecatedNodeIPAddress string `json:"nodeIPAddress,omitempty" pflag:"-,DEPRECATED. Please use DefaultConfig.[HeadNode|WorkerNode].IPAddress"` // Remote Ray Cluster Config - RemoteClusterConfig pluginmachinery.ClusterConfig `json:"remoteClusterConfig" pflag:"Configuration of remote K8s cluster for ray jobs"` - Logs logs.LogConfig `json:"logs" pflag:"-,Log configuration for ray jobs"` - LogsSidecar *v1.Container `json:"logsSidecar" pflag:"-,Sidecar to inject into head pods for capturing ray job logs"` - Defaults DefaultConfig `json:"defaults" pflag:"-,Default configuration for ray jobs"` - EnableUsageStats bool `json:"enableUsageStats" pflag:",Enable usage stats for ray jobs. These stats are submitted to usage-stats.ray.io per https://docs.ray.io/en/latest/cluster/usage-stats.html"` + RemoteClusterConfig pluginmachinery.ClusterConfig `json:"remoteClusterConfig" pflag:"Configuration of remote K8s cluster for ray jobs"` + Logs logs.LogConfig `json:"logs" pflag:"-,Log configuration for ray jobs"` + LogsSidecar *v1.Container `json:"logsSidecar" pflag:"-,Sidecar to inject into head pods for capturing ray job logs"` + DashboardURLTemplate *tasklog.TemplateLogPlugin `json:"dashboardURLTemplate" pflag:",Template for URL of Ray dashboard running on a head node."` + Defaults DefaultConfig `json:"defaults" pflag:"-,Default configuration for ray jobs"` + EnableUsageStats bool `json:"enableUsageStats" pflag:",Enable usage stats for ray jobs. These stats are submitted to usage-stats.ray.io per https://docs.ray.io/en/latest/cluster/usage-stats.html"` } type DefaultConfig struct { diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go index 1d0fde4ca8..0bc4f1183b 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go @@ -13,6 +13,7 @@ import ( "k8s.io/client-go/kubernetes/scheme" "sigs.k8s.io/controller-runtime/pkg/client" + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins" flyteerr "github.com/flyteorg/flyte/flyteplugins/go/tasks/errors" "github.com/flyteorg/flyte/flyteplugins/go/tasks/logs" @@ -437,26 +438,35 @@ func getEventInfoForRayJob(logConfig logs.LogConfig, pluginContext k8s.PluginCon return nil, fmt.Errorf("failed to initialize log plugins. Error: %w", err) } - if logPlugin == nil { - return nil, nil - } - - // TODO: Retrieve the name of head pod from rayJob.status, and add it to task logs - // RayJob CRD does not include the name of the worker or head pod for now + var taskLogs []*core.TaskLog taskExecID := pluginContext.TaskExecutionMetadata().GetTaskExecutionID() - logOutput, err := logPlugin.GetTaskLogs(tasklog.Input{ + input := tasklog.Input{ Namespace: rayJob.Namespace, TaskExecutionID: taskExecID, - }) + } + // TODO: Retrieve the name of head pod from rayJob.status, and add it to task logs + // RayJob CRD does not include the name of the worker or head pod for now + logOutput, err := logPlugin.GetTaskLogs(input) if err != nil { return nil, fmt.Errorf("failed to generate task logs. Error: %w", err) } + taskLogs = append(taskLogs, logOutput.TaskLogs...) - return &pluginsCore.TaskInfo{ - Logs: logOutput.TaskLogs, - }, nil + // Handling for Ray Dashboard + dashboardURLTemplate := GetConfig().DashboardURLTemplate + if dashboardURLTemplate != nil && + rayJob.Status.DashboardURL != "" && + rayJob.Status.JobStatus == rayv1alpha1.JobStatusRunning { + dashboardURLOutput, err := dashboardURLTemplate.GetTaskLogs(input) + if err != nil { + return nil, fmt.Errorf("failed to generate Ray dashboard link. Error: %w", err) + } + taskLogs = append(taskLogs, dashboardURLOutput.TaskLogs...) + } + + return &pluginsCore.TaskInfo{Logs: taskLogs}, nil } func (plugin rayJobResourceHandler) GetTaskPhase(ctx context.Context, pluginContext k8s.PluginContext, resource client.Object) (pluginsCore.PhaseInfo, error) { @@ -489,8 +499,14 @@ func (plugin rayJobResourceHandler) GetTaskPhase(ctx context.Context, pluginCont return pluginsCore.PhaseInfoFailure(flyteerr.TaskFailedWithError, reason, info), nil case rayv1alpha1.JobStatusSucceeded: return pluginsCore.PhaseInfoSuccess(info), nil - case rayv1alpha1.JobStatusPending, rayv1alpha1.JobStatusRunning: + case rayv1alpha1.JobStatusPending: return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, info), nil + case rayv1alpha1.JobStatusRunning: + phaseInfo := pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, info) + if len(info.Logs) > 0 { + phaseInfo = phaseInfo.WithVersion(pluginsCore.DefaultPhaseVersion + 1) + } + return phaseInfo, nil case rayv1alpha1.JobStatusStopped: // There is no current usage of this job status in KubeRay. It's unclear what it represents fallthrough diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go index 920fa85d61..ccb518fa03 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go @@ -24,6 +24,7 @@ import ( pluginIOMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io/mocks" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" mocks2 "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s/mocks" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" ) @@ -615,6 +616,8 @@ func newPluginContext() k8s.PluginContext { }, }, }) + taskExecID.OnGetUniqueNodeID().Return("unique-node") + taskExecID.OnGetGeneratedName().Return("generated-name") tskCtx := &mocks.TaskExecutionMetadata{} tskCtx.OnGetTaskExecutionID().Return(taskExecID) @@ -642,17 +645,19 @@ func TestGetTaskPhase(t *testing.T) { rayJobPhase rayv1alpha1.JobStatus rayClusterPhase rayv1alpha1.JobDeploymentStatus expectedCorePhase pluginsCore.Phase + expectedError bool }{ - {"", rayv1alpha1.JobDeploymentStatusInitializing, pluginsCore.PhaseInitializing}, - {rayv1alpha1.JobStatusPending, rayv1alpha1.JobDeploymentStatusFailedToGetOrCreateRayCluster, pluginsCore.PhasePermanentFailure}, - {rayv1alpha1.JobStatusPending, rayv1alpha1.JobDeploymentStatusWaitForDashboard, pluginsCore.PhaseRunning}, - {rayv1alpha1.JobStatusPending, rayv1alpha1.JobDeploymentStatusFailedJobDeploy, pluginsCore.PhasePermanentFailure}, - {rayv1alpha1.JobStatusPending, rayv1alpha1.JobDeploymentStatusRunning, pluginsCore.PhaseRunning}, - {rayv1alpha1.JobStatusPending, rayv1alpha1.JobDeploymentStatusFailedToGetJobStatus, pluginsCore.PhaseUndefined}, - {rayv1alpha1.JobStatusRunning, rayv1alpha1.JobDeploymentStatusRunning, pluginsCore.PhaseRunning}, - {rayv1alpha1.JobStatusFailed, rayv1alpha1.JobDeploymentStatusRunning, pluginsCore.PhasePermanentFailure}, - {rayv1alpha1.JobStatusSucceeded, rayv1alpha1.JobDeploymentStatusRunning, pluginsCore.PhaseSuccess}, - {rayv1alpha1.JobStatusSucceeded, rayv1alpha1.JobDeploymentStatusComplete, pluginsCore.PhaseSuccess}, + {"", rayv1alpha1.JobDeploymentStatusInitializing, pluginsCore.PhaseInitializing, false}, + {rayv1alpha1.JobStatusPending, rayv1alpha1.JobDeploymentStatusFailedToGetOrCreateRayCluster, pluginsCore.PhasePermanentFailure, false}, + {rayv1alpha1.JobStatusPending, rayv1alpha1.JobDeploymentStatusWaitForDashboard, pluginsCore.PhaseRunning, false}, + {rayv1alpha1.JobStatusPending, rayv1alpha1.JobDeploymentStatusFailedJobDeploy, pluginsCore.PhasePermanentFailure, false}, + {rayv1alpha1.JobStatusPending, rayv1alpha1.JobDeploymentStatusRunning, pluginsCore.PhaseRunning, false}, + {rayv1alpha1.JobStatusPending, rayv1alpha1.JobDeploymentStatusFailedToGetJobStatus, pluginsCore.PhaseRunning, false}, + {rayv1alpha1.JobStatusRunning, rayv1alpha1.JobDeploymentStatusRunning, pluginsCore.PhaseRunning, false}, + {rayv1alpha1.JobStatusFailed, rayv1alpha1.JobDeploymentStatusRunning, pluginsCore.PhasePermanentFailure, false}, + {rayv1alpha1.JobStatusSucceeded, rayv1alpha1.JobDeploymentStatusRunning, pluginsCore.PhaseSuccess, false}, + {rayv1alpha1.JobStatusSucceeded, rayv1alpha1.JobDeploymentStatusComplete, pluginsCore.PhaseSuccess, false}, + {rayv1alpha1.JobStatusStopped, rayv1alpha1.JobDeploymentStatusComplete, pluginsCore.PhaseUndefined, true}, } for _, tc := range testCases { @@ -663,12 +668,69 @@ func TestGetTaskPhase(t *testing.T) { startTime := metav1.NewTime(time.Now()) rayObject.Status.StartTime = &startTime phaseInfo, err := rayJobResourceHandler.GetTaskPhase(ctx, pluginCtx, rayObject) - assert.Nil(t, err) + if tc.expectedError { + assert.Error(t, err) + } else { + assert.Nil(t, err) + } assert.Equal(t, tc.expectedCorePhase.String(), phaseInfo.Phase().String()) }) } } +func TestGetEventInfo_DashboardURL(t *testing.T) { + pluginCtx := newPluginContext() + testCases := []struct { + name string + rayJob rayv1alpha1.RayJob + dashboardURLTemplate tasklog.TemplateLogPlugin + expectedTaskLogs []*core.TaskLog + }{ + { + name: "dashboard URL displayed", + rayJob: rayv1alpha1.RayJob{ + Status: rayv1alpha1.RayJobStatus{ + DashboardURL: "exists", + JobStatus: rayv1alpha1.JobStatusRunning, + }, + }, + dashboardURLTemplate: tasklog.TemplateLogPlugin{ + DisplayName: "Ray Dashboard", + TemplateURIs: []tasklog.TemplateURI{"http://test/{{.generatedName}}"}, + Scheme: tasklog.TemplateSchemeTaskExecution, + }, + expectedTaskLogs: []*core.TaskLog{ + { + Name: "Ray Dashboard", + Uri: "http://test/generated-name", + }, + }, + }, + { + name: "dashboard URL is not displayed", + rayJob: rayv1alpha1.RayJob{ + Status: rayv1alpha1.RayJobStatus{ + JobStatus: rayv1alpha1.JobStatusPending, + }, + }, + dashboardURLTemplate: tasklog.TemplateLogPlugin{ + DisplayName: "dummy", + TemplateURIs: []tasklog.TemplateURI{"http://dummy"}, + }, + expectedTaskLogs: nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.NoError(t, SetConfig(&Config{DashboardURLTemplate: &tc.dashboardURLTemplate})) + ti, err := getEventInfoForRayJob(logs.LogConfig{}, pluginCtx, &tc.rayJob) + assert.NoError(t, err) + assert.Equal(t, tc.expectedTaskLogs, ti.Logs) + }) + } +} + func TestGetPropertiesRay(t *testing.T) { rayJobResourceHandler := rayJobResourceHandler{} expected := k8s.PluginProperties{}