Skip to content

Commit

Permalink
Decouple ray submitter, worker, and head resources
Browse files Browse the repository at this point in the history
Signed-off-by: Jason Parraga <[email protected]>
  • Loading branch information
Sovietaced committed Nov 15, 2024
1 parent e19bbcc commit 185ba39
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 7 deletions.
2 changes: 1 addition & 1 deletion flytekit/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -1018,7 +1018,7 @@ def data_config(self) -> typing.Optional[DataLoadingConfig]:

def to_flyte_idl(self) -> _core_task.K8sPod:
    """
    Serialize this pod definition into a ``_core_task.K8sPod`` message.

    Every part (metadata, pod spec, data config) is converted only when it is
    set; otherwise ``None`` is passed for that field.

    :rtype: _core_task.K8sPod
    """
    return _core_task.K8sPod(
        # Guarded like the other fields: a K8sPod may be built with only a
        # pod_spec (e.g. K8sPod(pod_spec={...})), leaving metadata unset.
        metadata=self._metadata.to_flyte_idl() if self.metadata else None,
        # The pod spec dict is dumped to JSON and parsed into a protobuf Struct.
        pod_spec=_json_format.Parse(_json.dumps(self.pod_spec), _struct.Struct()) if self.pod_spec else None,
        data_config=self.data_config.to_flyte_idl() if self.data_config else None,
    )
Expand Down
27 changes: 26 additions & 1 deletion plugins/flytekit-ray/flytekitplugins/ray/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from flyteidl.plugins import ray_pb2 as _ray_pb2

from flytekit.models import common as _common
from flytekit.models.task import K8sPod


class WorkerGroupSpec(_common.FlyteIdlEntity):
Expand All @@ -13,12 +14,14 @@ def __init__(
min_replicas: typing.Optional[int] = None,
max_replicas: typing.Optional[int] = None,
ray_start_params: typing.Optional[typing.Dict[str, str]] = None,
k8s_pod: typing.Optional[K8sPod] = None,
):
self._group_name = group_name
self._replicas = replicas
self._max_replicas = max(replicas, max_replicas) if max_replicas is not None else replicas
self._min_replicas = min(replicas, min_replicas) if min_replicas is not None else replicas
self._ray_start_params = ray_start_params
self._k8s_pod = k8s_pod

@property
def group_name(self):
Expand Down Expand Up @@ -60,6 +63,14 @@ def ray_start_params(self):
"""
return self._ray_start_params

@property
def k8s_pod(self):
    """
    The K8sPod supplied at construction for the worker node pods, or None
    when no pod customization was given.
    :rtype: K8sPod
    """
    return self._k8s_pod

def to_flyte_idl(self):
"""
:rtype: flyteidl.plugins._ray_pb2.WorkerGroupSpec
Expand All @@ -70,6 +81,7 @@ def to_flyte_idl(self):
min_replicas=self.min_replicas,
max_replicas=self.max_replicas,
ray_start_params=self.ray_start_params,
k8s_pod=self.k8s_pod.to_flyte_idl() if self.k8s_pod else None,
)

@classmethod
Expand All @@ -84,30 +96,42 @@ def from_flyte_idl(cls, proto):
min_replicas=proto.min_replicas,
max_replicas=proto.max_replicas,
ray_start_params=proto.ray_start_params,
k8s_pod=K8sPod.from_flyte_idl(proto.k8s_pod) if proto.HasField("k8s_pod") else None,
)


class HeadGroupSpec(_common.FlyteIdlEntity):
def __init__(
    self,
    ray_start_params: typing.Optional[typing.Dict[str, str]] = None,
    k8s_pod: typing.Optional[K8sPod] = None,
):
    """
    Store the head group settings; both arguments are optional.

    :param ray_start_params: ray start params for the head node group.
    :param k8s_pod: pod customization for the head node pod.
    """
    self._k8s_pod = k8s_pod
    self._ray_start_params = ray_start_params

@property
def ray_start_params(self):
    """
    The ray start params of head node group.
    :rtype: typing.Optional[typing.Dict[str, str]]
    """
    return self._ray_start_params

@property
def k8s_pod(self):
    """
    The K8sPod supplied at construction for the head node pod, or None when
    no pod customization was given.
    :rtype: K8sPod
    """
    return self._k8s_pod

def to_flyte_idl(self):
    """
    Convert this head group spec into its protobuf message.

    :rtype: flyteidl.plugins._ray_pb2.HeadGroupSpec
    """
    # Unset or empty start params are sent as an empty mapping.
    start_params = self.ray_start_params or {}
    # The pod is converted only when one was provided.
    pod = self.k8s_pod
    pod_idl = pod.to_flyte_idl() if pod else None
    return _ray_pb2.HeadGroupSpec(ray_start_params=start_params, k8s_pod=pod_idl)

@classmethod
Expand All @@ -118,6 +142,7 @@ def from_flyte_idl(cls, proto):
"""
return cls(
ray_start_params=proto.ray_start_params,
k8s_pod=K8sPod.from_flyte_idl(proto.k8s_pod) if proto.HasField("k8s_pod") else None,
)


Expand Down
8 changes: 7 additions & 1 deletion plugins/flytekit-ray/flytekitplugins/ray/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@
from flytekit.core.context_manager import ExecutionParameters, FlyteContextManager
from flytekit.core.python_function_task import PythonFunctionTask
from flytekit.extend import TaskPlugins
from flytekit.models.task import K8sPod

ray = lazy_module("ray")


@dataclass
class HeadNodeConfig:
    """User-facing settings for the Ray head node; every field is optional."""

    # Forwarded to HeadGroupSpec as its ray start params when the task is serialized.
    ray_start_params: typing.Optional[typing.Dict[str, str]] = None
    # Forwarded to HeadGroupSpec as the head node's K8sPod.
    k8s_pod: typing.Optional[K8sPod] = None


@dataclass
Expand All @@ -35,6 +37,7 @@ class WorkerNodeConfig:
min_replicas: typing.Optional[int] = None
max_replicas: typing.Optional[int] = None
ray_start_params: typing.Optional[typing.Dict[str, str]] = None
k8s_pod: typing.Optional[K8sPod] = None


@dataclass
Expand Down Expand Up @@ -89,7 +92,9 @@ def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]
ray_job = RayJob(
ray_cluster=RayCluster(
head_group_spec=(
HeadGroupSpec(cfg.head_node_config.ray_start_params) if cfg.head_node_config else None
HeadGroupSpec(cfg.head_node_config.ray_start_params, cfg.head_node_config.k8s_pod)
if cfg.head_node_config
else None
),
worker_group_spec=[
WorkerGroupSpec(
Expand All @@ -98,6 +103,7 @@ def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]
c.min_replicas,
c.max_replicas,
c.ray_start_params,
c.k8s_pod,
)
for c in cfg.worker_node_config
],
Expand Down
2 changes: 1 addition & 1 deletion plugins/flytekit-ray/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

microlib_name = f"flytekitplugins-{PLUGIN_NAME}"

# NOTE(review): the flyteidl floor is presumably the first release whose ray_pb2
# HeadGroupSpec/WorkerGroupSpec messages carry the k8s_pod field — confirm.
plugin_requires = ["ray[default]", "flytekit>=1.3.0b2,<2.0.0", "flyteidl>=1.13.6"]

__version__ = "0.0.0+develop"

Expand Down
9 changes: 6 additions & 3 deletions plugins/flytekit-ray/tests/test_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,18 @@

import ray
import yaml
from flytekitplugins.ray.models import RayCluster, RayJob, WorkerGroupSpec
from flytekitplugins.ray import HeadNodeConfig
from flytekitplugins.ray.models import RayCluster, RayJob, WorkerGroupSpec, HeadGroupSpec
from flytekitplugins.ray.task import RayJobConfig, WorkerNodeConfig
from google.protobuf.json_format import MessageToDict
from flytekit.models.task import K8sPod

from flytekit import PythonFunctionTask, task
from flytekit.configuration import Image, ImageConfig, SerializationSettings

config = RayJobConfig(
worker_node_config=[WorkerNodeConfig(group_name="test_group", replicas=3, min_replicas=0, max_replicas=10)],
worker_node_config=[WorkerNodeConfig(group_name="test_group", replicas=3, min_replicas=0, max_replicas=10, k8s_pod=K8sPod(pod_spec={"str": "worker", "int": 1}))],
head_node_config=HeadNodeConfig(k8s_pod=K8sPod(pod_spec={"str": "head", "int": 2})),
runtime_env={"pip": ["numpy"]},
enable_autoscaling=True,
shutdown_after_job_finishes=True,
Expand Down Expand Up @@ -41,7 +44,7 @@ def t1(a: int) -> str:
)

ray_job_pb = RayJob(
ray_cluster=RayCluster(worker_group_spec=[WorkerGroupSpec("test_group", 3, 0, 10)], enable_autoscaling=True),
ray_cluster=RayCluster(worker_group_spec=[WorkerGroupSpec(group_name="test_group", replicas=3, min_replicas=0, max_replicas=10, k8s_pod=K8sPod(pod_spec={"str": "worker", "int": 1}))], head_group_spec=HeadGroupSpec(k8s_pod=K8sPod(pod_spec={"str": "head", "int": 2})), enable_autoscaling=True),
runtime_env=base64.b64encode(json.dumps({"pip": ["numpy"]}).encode()).decode(),
runtime_env_yaml=yaml.dump({"pip": ["numpy"]}),
shutdown_after_job_finishes=True,
Expand Down

0 comments on commit 185ba39

Please sign in to comment.