From 5be482f3d2e0f2ac5e17f77947a6a6dae803135a Mon Sep 17 00:00:00 2001 From: luke-lombardi <33990301+luke-lombardi@users.noreply.github.com> Date: Tue, 31 Dec 2024 15:33:54 -0500 Subject: [PATCH] add shell support everywhere --- sdk/src/beta9/abstractions/base/runner.py | 12 ++++++++++++ sdk/src/beta9/abstractions/endpoint.py | 8 -------- sdk/src/beta9/abstractions/mixins.py | 1 - 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/sdk/src/beta9/abstractions/base/runner.py b/sdk/src/beta9/abstractions/base/runner.py index cf5deaeee..f58baa954 100644 --- a/sdk/src/beta9/abstractions/base/runner.py +++ b/sdk/src/beta9/abstractions/base/runner.py @@ -27,6 +27,7 @@ SecretVar, ) from ...clients.gateway import TaskPolicy as TaskPolicyProto +from ...clients.shell import ShellServiceStub from ...config import ConfigContext, SDKSettings, get_config_context, get_settings from ...env import called_on_import from ...sync import FileSyncer, SyncEventHandler @@ -144,6 +145,7 @@ def __init__( self._map_callable_to_attr(attr="on_start", func=on_start) self._gateway_stub: Optional[GatewayServiceStub] = None + self._shell_stub: Optional[ShellServiceStub] = None self.syncer: FileSyncer = FileSyncer(self.gateway_stub) self.settings: SDKSettings = get_settings() self.config_context: ConfigContext = get_config_context() @@ -212,6 +214,16 @@ def gateway_stub(self) -> GatewayServiceStub: def gateway_stub(self, value) -> None: self._gateway_stub = value + @property + def shell_stub(self) -> ShellServiceStub: + if not self._shell_stub: + self._shell_stub = ShellServiceStub(self.channel) + return self._shell_stub + + @shell_stub.setter + def shell_stub(self, value) -> None: + self._shell_stub = value + def _parse_cpu_to_millicores(self, cpu: Union[float, str]) -> int: """ Parse the cpu argument to an integer value in millicores. diff --git a/sdk/src/beta9/abstractions/endpoint.py b/sdk/src/beta9/abstractions/endpoint.py index 623c67eb9..b25651b59 100644 --- a/sdk/src/beta9/abstractions/endpoint.py +++ b/sdk/src/beta9/abstractions/endpoint.py @@ -26,7 +26,6 @@ StartEndpointServeResponse, StopEndpointServeRequest, ) -from ..clients.shell import ShellServiceStub from ..env import is_local from ..type import Autoscaler, GpuType, GpuTypeAlias, QueueDepthAutoscaler, TaskPolicy from .mixins import DeployableMixin @@ -160,7 +159,6 @@ def __init__( ) self._endpoint_stub: Optional[EndpointServiceStub] = None - self._shell_stub: Optional[ShellServiceStub] = None @property def endpoint_stub(self) -> EndpointServiceStub: @@ -168,12 +166,6 @@ def endpoint_stub(self) -> EndpointServiceStub: self._endpoint_stub = EndpointServiceStub(self.channel) return self._endpoint_stub - @property - def shell_stub(self) -> ShellServiceStub: - if not self._shell_stub: - self._shell_stub = ShellServiceStub(self.channel) - return self._shell_stub - def __call__(self, func): return _CallableWrapper(func, self) diff --git a/sdk/src/beta9/abstractions/mixins.py b/sdk/src/beta9/abstractions/mixins.py index 79d545279..714e80551 100644 --- a/sdk/src/beta9/abstractions/mixins.py +++ b/sdk/src/beta9/abstractions/mixins.py @@ -16,7 +16,6 @@ class DeployableMixin: func: Callable parent: RunnerAbstraction deployment_id: Optional[str] = None - deployment_stub_type: ClassVar[str] def _validate(self):