Skip to content

Commit

Permalink
Update ray task to use latest updates in flyteidl
Browse files Browse the repository at this point in the history
Signed-off-by: Jason Parraga <[email protected]>
  • Loading branch information
Sovietaced committed Nov 13, 2024
1 parent a6d6335 commit c621359
Show file tree
Hide file tree
Showing 4 changed files with 189 additions and 5 deletions.
176 changes: 176 additions & 0 deletions plugins/flytekit-ray/flytekitplugins/ray/models.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,170 @@
import json as _json
import typing

from flyteidl.core import tasks_pb2 as _core_task
from flyteidl.plugins import ray_pb2 as _ray_pb2
from google.protobuf import json_format as _json_format
from google.protobuf import struct_pb2 as _struct

from flytekit.models import common as _common


class K8sObjectMetadata(_common.FlyteIdlEntity):
def __init__(self, labels: typing.Dict[str, str] = None, annotations: typing.Dict[str, str] = None):
"""
This defines additional metadata for building a kubernetes pod.
"""
self._labels = labels
self._annotations = annotations

@property
def labels(self) -> typing.Dict[str, str]:
return self._labels

@property
def annotations(self) -> typing.Dict[str, str]:
return self._annotations

def to_flyte_idl(self) -> _core_task.K8sObjectMetadata:
return _core_task.K8sObjectMetadata(
labels={k: v for k, v in self.labels.items()} if self.labels is not None else None,
annotations={k: v for k, v in self.annotations.items()} if self.annotations is not None else None,
)

@classmethod
def from_flyte_idl(cls, pb2_object: _core_task.K8sObjectMetadata):
return cls(
labels={k: v for k, v in pb2_object.labels.items()} if pb2_object.labels is not None else None,
annotations={k: v for k, v in pb2_object.annotations.items()}
if pb2_object.annotations is not None
else None,
)


class IOStrategy(_common.FlyteIdlEntity):
"""
Provides methods to manage data in and out of the Raw container using Download Modes. This can only be used if DataLoadingConfig is enabled.
"""

DOWNLOAD_MODE_EAGER = _core_task.IOStrategy.DOWNLOAD_EAGER
DOWNLOAD_MODE_STREAM = _core_task.IOStrategy.DOWNLOAD_STREAM
DOWNLOAD_MODE_NO_DOWNLOAD = _core_task.IOStrategy.DO_NOT_DOWNLOAD

UPLOAD_MODE_EAGER = _core_task.IOStrategy.UPLOAD_EAGER
UPLOAD_MODE_ON_EXIT = _core_task.IOStrategy.UPLOAD_ON_EXIT
UPLOAD_MODE_NO_UPLOAD = _core_task.IOStrategy.DO_NOT_UPLOAD

def __init__(
self,
download_mode: _core_task.IOStrategy.DownloadMode = DOWNLOAD_MODE_EAGER,
upload_mode: _core_task.IOStrategy.UploadMode = UPLOAD_MODE_ON_EXIT,
):
self._download_mode = download_mode
self._upload_mode = upload_mode

def to_flyte_idl(self) -> _core_task.IOStrategy:
return _core_task.IOStrategy(download_mode=self._download_mode, upload_mode=self._upload_mode)

@classmethod
def from_flyte_idl(cls, pb2_object: _core_task.IOStrategy):
if pb2_object is None:
return None
return cls(
download_mode=pb2_object.download_mode,
upload_mode=pb2_object.upload_mode,
)


class DataLoadingConfig(_common.FlyteIdlEntity):
LITERALMAP_FORMAT_PROTO = _core_task.DataLoadingConfig.PROTO
LITERALMAP_FORMAT_JSON = _core_task.DataLoadingConfig.JSON
LITERALMAP_FORMAT_YAML = _core_task.DataLoadingConfig.YAML
_LITERALMAP_FORMATS = frozenset([LITERALMAP_FORMAT_JSON, LITERALMAP_FORMAT_PROTO, LITERALMAP_FORMAT_YAML])

def __init__(
self,
input_path: str,
output_path: str,
enabled: bool = True,
format: _core_task.DataLoadingConfig.LiteralMapFormat = LITERALMAP_FORMAT_PROTO,
io_strategy: IOStrategy = None,
):
if format not in self._LITERALMAP_FORMATS:
raise ValueError(
"Metadata format {} not supported. Should be one of {}".format(format, self._LITERALMAP_FORMATS)
)
self._input_path = input_path
self._output_path = output_path
self._enabled = enabled
self._format = format
self._io_strategy = io_strategy

def to_flyte_idl(self) -> _core_task.DataLoadingConfig:
return _core_task.DataLoadingConfig(
input_path=self._input_path,
output_path=self._output_path,
format=self._format,
enabled=self._enabled,
io_strategy=self._io_strategy.to_flyte_idl() if self._io_strategy is not None else None,
)

@classmethod
def from_flyte_idl(cls, pb2: _core_task.DataLoadingConfig) -> "DataLoadingConfig":
if pb2 is None:
return None
return cls(
input_path=pb2.input_path,
output_path=pb2.output_path,
enabled=pb2.enabled,
format=pb2.format,
io_strategy=IOStrategy.from_flyte_idl(pb2.io_strategy) if pb2.HasField("io_strategy") else None,
)


class K8sPod(_common.FlyteIdlEntity):
def __init__(
self,
metadata: K8sObjectMetadata = None,
pod_spec: typing.Dict[str, typing.Any] = None,
data_config: typing.Optional[DataLoadingConfig] = None,
):
"""
This defines a kubernetes pod target. It will build the pod target during task execution
"""
self._metadata = metadata
self._pod_spec = pod_spec
self._data_config = data_config

@property
def metadata(self) -> K8sObjectMetadata:
return self._metadata

@property
def pod_spec(self) -> typing.Dict[str, typing.Any]:
return self._pod_spec

@property
def data_config(self) -> typing.Optional[DataLoadingConfig]:
return self._data_config

def to_flyte_idl(self) -> _core_task.K8sPod:
return _core_task.K8sPod(
metadata=self._metadata.to_flyte_idl() if self._metadata else None,
pod_spec=_json_format.Parse(_json.dumps(self.pod_spec), _struct.Struct()) if self.pod_spec else None,
data_config=self.data_config.to_flyte_idl() if self.data_config else None,
)

@classmethod
def from_flyte_idl(cls, pb2_object: _core_task.K8sPod):
return cls(
metadata=K8sObjectMetadata.from_flyte_idl(pb2_object.metadata) if pb2_object.HasField("metadata") else None,
pod_spec=_json_format.MessageToDict(pb2_object.pod_spec) if pb2_object.HasField("pod_spec") else None,
data_config=DataLoadingConfig.from_flyte_idl(pb2_object.data_config)
if pb2_object.HasField("data_config")
else None,
)


class WorkerGroupSpec(_common.FlyteIdlEntity):
def __init__(
self,
Expand All @@ -13,12 +173,14 @@ def __init__(
min_replicas: typing.Optional[int] = None,
max_replicas: typing.Optional[int] = None,
ray_start_params: typing.Optional[typing.Dict[str, str]] = None,
k8s_pod: typing.Optional[K8sPod] = None,
):
self._group_name = group_name
self._replicas = replicas
self._max_replicas = max(replicas, max_replicas) if max_replicas is not None else replicas
self._min_replicas = min(replicas, min_replicas) if min_replicas is not None else replicas
self._ray_start_params = ray_start_params
self._k8s_pod = k8s_pod

@property
def group_name(self):
Expand Down Expand Up @@ -60,6 +222,10 @@ def ray_start_params(self):
"""
return self._ray_start_params

@property
def k8s_pod(self):
return self._k8s_pod

def to_flyte_idl(self):
"""
:rtype: flyteidl.plugins._ray_pb2.WorkerGroupSpec
Expand All @@ -70,6 +236,7 @@ def to_flyte_idl(self):
min_replicas=self.min_replicas,
max_replicas=self.max_replicas,
ray_start_params=self.ray_start_params,
k8s_pod=self.k8s_pod.to_flyte_idl() if self.k8s_pod else None,
)

@classmethod
Expand All @@ -84,15 +251,18 @@ def from_flyte_idl(cls, proto):
min_replicas=proto.min_replicas,
max_replicas=proto.max_replicas,
ray_start_params=proto.ray_start_params,
k8s_pod=K8sPod.from_flyte_idl(proto.k8s_pod) if proto.HasField("k8s_pod") else None,
)


class HeadGroupSpec(_common.FlyteIdlEntity):
def __init__(
self,
ray_start_params: typing.Optional[typing.Dict[str, str]] = None,
k8s_pod: typing.Optional[K8sPod] = None,
):
self._ray_start_params = ray_start_params
self._k8s_pod = k8s_pod

@property
def ray_start_params(self):
Expand All @@ -102,12 +272,17 @@ def ray_start_params(self):
"""
return self._ray_start_params

@property
def k8s_pod(self):
return self._k8s_pod

def to_flyte_idl(self):
"""
:rtype: flyteidl.plugins._ray_pb2.HeadGroupSpec
"""
return _ray_pb2.HeadGroupSpec(
ray_start_params=self.ray_start_params if self.ray_start_params else {},
k8s_pod=self.k8s_pod.to_flyte_idl() if self.k8s_pod else None,
)

@classmethod
Expand All @@ -118,6 +293,7 @@ def from_flyte_idl(cls, proto):
"""
return cls(
ray_start_params=proto.ray_start_params,
k8s_pod=K8sPod.from_flyte_idl(proto.k8s_pod) if proto.HasField("k8s_pod") else None,
)


Expand Down
8 changes: 7 additions & 1 deletion plugins/flytekit-ray/flytekitplugins/ray/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@
from flytekit.core.context_manager import ExecutionParameters, FlyteContextManager
from flytekit.core.python_function_task import PythonFunctionTask
from flytekit.extend import TaskPlugins
from flytekit.models.task import K8sPod

ray = lazy_module("ray")


@dataclass
class HeadNodeConfig:
ray_start_params: typing.Optional[typing.Dict[str, str]] = None
k8s_pod: typing.Optional[K8sPod] = None


@dataclass
Expand All @@ -35,6 +37,7 @@ class WorkerNodeConfig:
min_replicas: typing.Optional[int] = None
max_replicas: typing.Optional[int] = None
ray_start_params: typing.Optional[typing.Dict[str, str]] = None
k8s_pod: typing.Optional[K8sPod] = None


@dataclass
Expand Down Expand Up @@ -89,7 +92,9 @@ def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]
ray_job = RayJob(
ray_cluster=RayCluster(
head_group_spec=(
HeadGroupSpec(cfg.head_node_config.ray_start_params) if cfg.head_node_config else None
HeadGroupSpec(cfg.head_node_config.ray_start_params, cfg.head_node_config.k8s_pod)
if cfg.head_node_config
else None
),
worker_group_spec=[
WorkerGroupSpec(
Expand All @@ -98,6 +103,7 @@ def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]
c.min_replicas,
c.max_replicas,
c.ray_start_params,
c.k8s_pod,
)
for c in cfg.worker_node_config
],
Expand Down
2 changes: 1 addition & 1 deletion plugins/flytekit-ray/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

microlib_name = f"flytekitplugins-{PLUGIN_NAME}"

plugin_requires = ["ray[default]", "flytekit>=1.3.0b2,<2.0.0", "flyteidl>=1.1.10"]
plugin_requires = ["ray[default]", "flytekit>=1.3.0b2,<2.0.0", "flyteidl>=1.13.6"]

__version__ = "0.0.0+develop"

Expand Down
8 changes: 5 additions & 3 deletions plugins/flytekit-ray/tests/test_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,17 @@

import ray
import yaml
from flytekitplugins.ray.models import RayCluster, RayJob, WorkerGroupSpec
from flytekitplugins.ray import HeadNodeConfig
from flytekitplugins.ray.models import RayCluster, RayJob, WorkerGroupSpec, K8sPod, HeadGroupSpec
from flytekitplugins.ray.task import RayJobConfig, WorkerNodeConfig
from google.protobuf.json_format import MessageToDict

from flytekit import PythonFunctionTask, task
from flytekit.configuration import Image, ImageConfig, SerializationSettings

config = RayJobConfig(
worker_node_config=[WorkerNodeConfig(group_name="test_group", replicas=3, min_replicas=0, max_replicas=10)],
worker_node_config=[WorkerNodeConfig(group_name="test_group", replicas=3, min_replicas=0, max_replicas=10, k8s_pod=K8sPod(pod_spec={"str": "worker", "int": 1}))],
head_node_config=HeadNodeConfig(k8s_pod=K8sPod(pod_spec={"str": "head", "int": 2})),
runtime_env={"pip": ["numpy"]},
enable_autoscaling=True,
shutdown_after_job_finishes=True,
Expand Down Expand Up @@ -41,7 +43,7 @@ def t1(a: int) -> str:
)

ray_job_pb = RayJob(
ray_cluster=RayCluster(worker_group_spec=[WorkerGroupSpec("test_group", 3, 0, 10)], enable_autoscaling=True),
ray_cluster=RayCluster(worker_group_spec=[WorkerGroupSpec(group_name="test_group", replicas=3, min_replicas=0, max_replicas=10, k8s_pod=K8sPod(pod_spec={"str": "worker", "int": 1}))], head_group_spec=HeadGroupSpec(k8s_pod=K8sPod(pod_spec={"str": "head", "int": 2})), enable_autoscaling=True),
runtime_env=base64.b64encode(json.dumps({"pip": ["numpy"]}).encode()).decode(),
runtime_env_yaml=yaml.dump({"pip": ["numpy"]}),
shutdown_after_job_finishes=True,
Expand Down

0 comments on commit c621359

Please sign in to comment.