Skip to content

Commit

Permalink
load ces async (#399)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jack-Q authored Aug 30, 2024
1 parent 49cf816 commit aa23765
Show file tree
Hide file tree
Showing 8 changed files with 161 additions and 36 deletions.
5 changes: 3 additions & 2 deletions taskweaver/ces/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
import secrets
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union

from taskweaver.plugin.context import ArtifactType
if TYPE_CHECKING:
from taskweaver.plugin.context import ArtifactType


@dataclass
Expand Down
25 changes: 20 additions & 5 deletions taskweaver/ces/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def __init__(
self.id = get_id(prefix="env") if env_id is None else env_id
self.env_dir = env_dir if env_dir is not None else os.getcwd()
self.mode = env_mode
self._client: Optional[BlockingKernelClient] = None

if self.mode == EnvMode.Local:
self.multi_kernel_manager = TaskWeaverMultiKernelManager(
Expand Down Expand Up @@ -384,6 +385,7 @@ def update_session_var(
session.session_var.update(session_var)

def stop_session(self, session_id: str) -> None:
self._clean_client()
session = self._get_session(session_id)
if session is None:
# session not exist
Expand Down Expand Up @@ -477,6 +479,8 @@ def _get_client(
self,
session_id: str,
) -> BlockingKernelClient:
if self._client is not None:
return self._client
session = self._get_session(session_id)
connection_file = self._get_connection_file(session_id, session.kernel_id)
client = BlockingKernelClient(connection_file=connection_file)
Expand All @@ -490,8 +494,16 @@ def _get_client(
client.hb_port = ports["hb_port"]
client.control_port = ports["control_port"]
client.iopub_port = ports["iopub_port"]
client.wait_for_ready(timeout=30)
client.start_channels()
self._client = client
return client

def _clean_client(self):
if self._client is not None:
self._client.stop_channels()
self._client = None

def _execute_code_on_kernel(
self,
session_id: str,
Expand All @@ -503,8 +515,6 @@ def _execute_code_on_kernel(
) -> EnvExecution:
exec_result = EnvExecution(exec_id=exec_id, code=code, exec_type=exec_type)
kc = self._get_client(session_id)
kc.wait_for_ready(timeout=30)
kc.start_channels()
result_msg_id = kc.execute(
code=code,
silent=silent,
Expand All @@ -515,11 +525,16 @@ def _execute_code_on_kernel(
try:
# TODO: interrupt kernel if it takes too long
while True:
message = kc.get_iopub_msg(timeout=180)
from taskweaver.utils.time_usage import time_usage

with time_usage() as time_msg:
message = kc.get_iopub_msg(timeout=180)
logger.debug((f"Time: {time_msg.total:.2f} \t MsgType: {message['msg_type']} \t Code: {code}"))
logger.debug(json.dumps(message, indent=2, default=str))

assert message["parent_header"]["msg_id"] == result_msg_id
if message["parent_header"]["msg_id"] != result_msg_id:
# skip messages not related to the current execution
continue
msg_type = message["msg_type"]
if msg_type == "status":
if message["content"]["execution_state"] == "idle":
Expand Down Expand Up @@ -565,7 +580,7 @@ def _execute_code_on_kernel(
else:
pass
finally:
kc.stop_channels()
pass
return exec_result

def _update_session_var(self, session: EnvSession) -> None:
Expand Down
126 changes: 104 additions & 22 deletions taskweaver/ces/manager/defer.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,80 @@
from __future__ import annotations

from typing import Callable, Dict, Optional
from typing import Callable, Dict, Optional, Tuple, TypeVar

from taskweaver.ces.common import Client, ExecutionResult, KernelModeType, Manager

TaskResult = TypeVar("TaskResult")


def deferred_var(
name: str,
init: Callable[[], TaskResult],
threaded: bool,
) -> Callable[[], TaskResult]:
result: Optional[Tuple[TaskResult]] = None
if not threaded:
result = (init(),)

def sync_result() -> TaskResult:
assert result is not None
return result[0]

return sync_result

import threading

lock = threading.Lock()
loaded_event = threading.Event()
thread: Optional[threading.Thread] = None

def task() -> None:
nonlocal result
result = (init(),)
loaded_event.set()

def async_result() -> TaskResult:
nonlocal result, thread
loaded_event.wait()
with lock:
if thread is not None:
thread.join()
thread = None

assert result is not None
return result[0]

with lock:
threading.Thread(target=task, daemon=True).start()

return async_result


class DeferredClient(Client):
def __init__(self, client_factory: Callable[[], Client]) -> None:
def __init__(
self,
client_factory: Callable[[], Client],
async_warm_up: bool = False,
) -> None:
self.client_factory = client_factory
self.proxy_client: Optional[Client] = None
self.async_warm_up = async_warm_up
self.deferred_var: Optional[Callable[[], Client]] = None

def start(self) -> None:
# defer the start to the proxy client
pass
if self.async_warm_up:
self._init_deferred_var()

def stop(self) -> None:
if self.proxy_client is not None:
self.proxy_client.stop()
if self.deferred_var is not None:
self.deferred_var().stop()

def load_plugin(self, plugin_name: str, plugin_code: str, plugin_config: Dict[str, str]) -> None:
def load_plugin(
self,
plugin_name: str,
plugin_code: str,
plugin_config: Dict[str, str],
) -> None:
self._get_proxy_client().load_plugin(plugin_name, plugin_code, plugin_config)

def test_plugin(self, plugin_name: str) -> None:
Expand All @@ -31,26 +87,41 @@ def execute_code(self, exec_id: str, code: str) -> ExecutionResult:
return self._get_proxy_client().execute_code(exec_id, code)

def _get_proxy_client(self) -> Client:
if self.proxy_client is None:
self.proxy_client = self.client_factory()
self.proxy_client.start()
return self.proxy_client
return self._init_deferred_var()()

def _init_deferred_var(self) -> Callable[[], Client]:
if self.deferred_var is None:

def task() -> Client:
client = self.client_factory()
client.start()
return client

self.deferred_var = deferred_var("DeferredClient", task, self.async_warm_up)
return self.deferred_var


class DeferredManager(Manager):
def __init__(self, kernel_mode: KernelModeType, manager_factory: Callable[[], Manager]) -> None:
def __init__(
self,
kernel_mode: KernelModeType,
manager_factory: Callable[[], Manager],
async_warm_up: bool = True,
) -> None:
super().__init__()
self.kernel_mode: KernelModeType = kernel_mode
self.manager_factory = manager_factory
self.proxy_manager: Optional[Manager] = None
self.async_warm_up = async_warm_up
self.deferred_var: Optional[Callable[[], Manager]] = None

def initialize(self) -> None:
# defer the initialization to the proxy manager
pass
if self.async_warm_up:
self._init_deferred_var()

def clean_up(self) -> None:
if self.proxy_manager is not None:
self.proxy_manager.clean_up()
if self.deferred_var is not None:
self.deferred_var().clean_up()

def get_session_client(
self,
Expand All @@ -60,15 +131,26 @@ def get_session_client(
cwd: Optional[str] = None,
) -> DeferredClient:
def client_factory() -> Client:
return self._get_proxy_manager().get_session_client(session_id, env_id, session_dir, cwd)
return self._get_proxy_manager().get_session_client(
session_id,
env_id,
session_dir,
cwd,
)

return DeferredClient(client_factory)
return DeferredClient(client_factory, self.async_warm_up)

def get_kernel_mode(self) -> KernelModeType:
return self.kernel_mode

def _get_proxy_manager(self) -> Manager:
if self.proxy_manager is None:
self.proxy_manager = self.manager_factory()
self.proxy_manager.initialize()
return self.proxy_manager
return self._init_deferred_var()()

def _init_deferred_var(self) -> Callable[[], Manager]:
if self.deferred_var is None:
self.deferred_var = deferred_var(
"DeferredManager",
self.manager_factory,
self.async_warm_up,
)
return self.deferred_var
14 changes: 7 additions & 7 deletions taskweaver/code_interpreter/code_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,16 +67,14 @@ def __init__(

@tracing_decorator
def execute_code(self, exec_id: str, code: str) -> ExecutionResult:
if not self.client_started:
with get_tracer().start_as_current_span("start"):
self.start()
self.client_started = True
with get_tracer().start_as_current_span("start"):
self.start()

if not self.plugin_loaded:
with get_tracer().start_as_current_span("load_plugin"):
self.load_plugin()
self.plugin_loaded = True

# update session variables
self.exec_client.update_session_var(self.session_variables)

Expand Down Expand Up @@ -111,7 +109,7 @@ def execute_code(self, exec_id: str, code: str) -> ExecutionResult:
)

return result

def update_session_var(self, session_var_dict: dict) -> None:
self.session_variables.update(session_var_dict)

Expand Down Expand Up @@ -146,7 +144,9 @@ def load_plugin(self):
print(f"Plugin {p.name} failed to load: {str(e)}")

def start(self):
self.exec_client.start()
if not self.client_started:
self.exec_client.start()
self.client_started = True

def stop(self):
self.exec_client.stop()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def reply(
) -> Post:
post_proxy = self.event_emitter.create_post_proxy(self.alias)
post_proxy.update_status("generating code")
self.executor.start()
self.generator.reply(
memory,
post_proxy,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def reply(
**kwargs: ...,
) -> Post:
post_proxy = self.event_emitter.create_post_proxy(self.alias)
self.executor.start()
self.generator.reply(
memory,
post_proxy=post_proxy,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def reply(
**kwargs: ...,
) -> Post:
post_proxy = self.event_emitter.create_post_proxy(self.alias)
self.executor.start()
self.generator.reply(
memory,
post_proxy=post_proxy,
Expand Down
24 changes: 24 additions & 0 deletions taskweaver/utils/time_usage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import time
from contextlib import contextmanager
from dataclasses import dataclass


@dataclass()
class TimeUsage:
start: float
end: float
process: float
total: float


@contextmanager
def time_usage():
usage = TimeUsage(start=time.time(), end=0, process=0, total=0)
perf_time_start = time.perf_counter_ns()
process_start = time.process_time_ns()
yield usage
process_end = time.process_time_ns()
perf_time_end = time.perf_counter_ns()
usage.end = time.time()
usage.process = round((process_end - process_start) / 1e6, 3)
usage.total = round((perf_time_end - perf_time_start) / 1e6, 3)

0 comments on commit aa23765

Please sign in to comment.