From aa23765c17d35887a82e91b56c722b30c30387cc Mon Sep 17 00:00:00 2001 From: Jack Q Date: Fri, 30 Aug 2024 16:31:31 +0800 Subject: [PATCH] load ces async (#399) --- taskweaver/ces/common.py | 5 +- taskweaver/ces/environment.py | 25 +++- taskweaver/ces/manager/defer.py | 126 +++++++++++++++--- taskweaver/code_interpreter/code_executor.py | 14 +- .../code_interpreter/code_interpreter.py | 1 + .../code_interpreter_cli_only.py | 1 + .../code_interpreter_plugin_only.py | 1 + taskweaver/utils/time_usage.py | 24 ++++ 8 files changed, 161 insertions(+), 36 deletions(-) create mode 100644 taskweaver/utils/time_usage.py diff --git a/taskweaver/ces/common.py b/taskweaver/ces/common.py index 8270452e..ebcd89fe 100644 --- a/taskweaver/ces/common.py +++ b/taskweaver/ces/common.py @@ -4,9 +4,10 @@ import secrets from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Dict, List, Literal, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union -from taskweaver.plugin.context import ArtifactType +if TYPE_CHECKING: + from taskweaver.plugin.context import ArtifactType @dataclass diff --git a/taskweaver/ces/environment.py b/taskweaver/ces/environment.py index a099cc42..2eb5adc1 100644 --- a/taskweaver/ces/environment.py +++ b/taskweaver/ces/environment.py @@ -122,6 +122,7 @@ def __init__( self.id = get_id(prefix="env") if env_id is None else env_id self.env_dir = env_dir if env_dir is not None else os.getcwd() self.mode = env_mode + self._client: Optional[BlockingKernelClient] = None if self.mode == EnvMode.Local: self.multi_kernel_manager = TaskWeaverMultiKernelManager( @@ -384,6 +385,7 @@ def update_session_var( session.session_var.update(session_var) def stop_session(self, session_id: str) -> None: + self._clean_client() session = self._get_session(session_id) if session is None: # session not exist @@ -477,6 +479,8 @@ def _get_client( self, session_id: str, ) -> BlockingKernelClient: + if self._client is not None: + return self._client session = self._get_session(session_id) connection_file = self._get_connection_file(session_id, session.kernel_id) client = BlockingKernelClient(connection_file=connection_file) @@ -490,8 +494,16 @@ def _get_client( client.hb_port = ports["hb_port"] client.control_port = ports["control_port"] client.iopub_port = ports["iopub_port"] + client.wait_for_ready(timeout=30) + client.start_channels() + self._client = client return client + def _clean_client(self): + if self._client is not None: + self._client.stop_channels() + self._client = None + def _execute_code_on_kernel( self, session_id: str, @@ -503,8 +515,6 @@ def _execute_code_on_kernel( ) -> EnvExecution: exec_result = EnvExecution(exec_id=exec_id, code=code, exec_type=exec_type) kc = self._get_client(session_id) - kc.wait_for_ready(timeout=30) - kc.start_channels() result_msg_id = kc.execute( code=code, silent=silent, @@ -515,11 +525,16 @@ def _execute_code_on_kernel( try: # TODO: interrupt kernel if it takes too long while True: - message = kc.get_iopub_msg(timeout=180) + from taskweaver.utils.time_usage import time_usage + with time_usage() as time_msg: + message = kc.get_iopub_msg(timeout=180) + logger.debug((f"Time: {time_msg.total:.2f} \t MsgType: {message['msg_type']} \t Code: {code}")) logger.debug(json.dumps(message, indent=2, default=str)) - assert message["parent_header"]["msg_id"] == result_msg_id + if message["parent_header"]["msg_id"] != result_msg_id: + # skip messages not related to the current execution + continue msg_type = message["msg_type"] if msg_type == "status": if message["content"]["execution_state"] == "idle": @@ -565,7 +580,7 @@ def _execute_code_on_kernel( else: pass finally: - kc.stop_channels() + pass return exec_result def _update_session_var(self, session: EnvSession) -> None: diff --git a/taskweaver/ces/manager/defer.py b/taskweaver/ces/manager/defer.py index e3c5fa36..5b9fd0a0 100644 --- a/taskweaver/ces/manager/defer.py +++ b/taskweaver/ces/manager/defer.py @@ -1,24 +1,80 @@ from __future__ import annotations -from typing import Callable, Dict, Optional +from typing import Callable, Dict, Optional, Tuple, TypeVar from taskweaver.ces.common import Client, ExecutionResult, KernelModeType, Manager +TaskResult = TypeVar("TaskResult") + + +def deferred_var( + name: str, + init: Callable[[], TaskResult], + threaded: bool, +) -> Callable[[], TaskResult]: + result: Optional[Tuple[TaskResult]] = None + if not threaded: + result = (init(),) + + def sync_result() -> TaskResult: + assert result is not None + return result[0] + + return sync_result + + import threading + + lock = threading.Lock() + loaded_event = threading.Event() + thread: Optional[threading.Thread] = None + + def task() -> None: + nonlocal result + result = (init(),) + loaded_event.set() + + def async_result() -> TaskResult: + nonlocal result, thread + loaded_event.wait() + with lock: + if thread is not None: + thread.join() + thread = None + + assert result is not None + return result[0] + + with lock: + threading.Thread(target=task, daemon=True).start() + + return async_result + class DeferredClient(Client): - def __init__(self, client_factory: Callable[[], Client]) -> None: + def __init__( + self, + client_factory: Callable[[], Client], + async_warm_up: bool = False, + ) -> None: self.client_factory = client_factory - self.proxy_client: Optional[Client] = None + self.async_warm_up = async_warm_up + self.deferred_var: Optional[Callable[[], Client]] = None def start(self) -> None: # defer the start to the proxy client - pass + if self.async_warm_up: + self._init_deferred_var() def stop(self) -> None: - if self.proxy_client is not None: - self.proxy_client.stop() + if self.deferred_var is not None: + self.deferred_var().stop() - def load_plugin(self, plugin_name: str, plugin_code: str, plugin_config: Dict[str, str]) -> None: + def load_plugin( + self, + plugin_name: str, + plugin_code: str, + plugin_config: Dict[str, str], + ) -> None: self._get_proxy_client().load_plugin(plugin_name, plugin_code, plugin_config) def test_plugin(self, plugin_name: str) -> None: @@ -31,26 +87,41 @@ def execute_code(self, exec_id: str, code: str) -> ExecutionResult: return self._get_proxy_client().execute_code(exec_id, code) def _get_proxy_client(self) -> Client: - if self.proxy_client is None: - self.proxy_client = self.client_factory() - self.proxy_client.start() - return self.proxy_client + return self._init_deferred_var()() + + def _init_deferred_var(self) -> Callable[[], Client]: + if self.deferred_var is None: + + def task() -> Client: + client = self.client_factory() + client.start() + return client + + self.deferred_var = deferred_var("DeferredClient", task, self.async_warm_up) + return self.deferred_var class DeferredManager(Manager): - def __init__(self, kernel_mode: KernelModeType, manager_factory: Callable[[], Manager]) -> None: + def __init__( + self, + kernel_mode: KernelModeType, + manager_factory: Callable[[], Manager], + async_warm_up: bool = True, + ) -> None: super().__init__() self.kernel_mode: KernelModeType = kernel_mode self.manager_factory = manager_factory - self.proxy_manager: Optional[Manager] = None + self.async_warm_up = async_warm_up + self.deferred_var: Optional[Callable[[], Manager]] = None def initialize(self) -> None: # defer the initialization to the proxy manager - pass + if self.async_warm_up: + self._init_deferred_var() def clean_up(self) -> None: - if self.proxy_manager is not None: - self.proxy_manager.clean_up() + if self.deferred_var is not None: + self.deferred_var().clean_up() def get_session_client( self, @@ -60,15 +131,26 @@ def get_session_client( cwd: Optional[str] = None, ) -> DeferredClient: def client_factory() -> Client: - return self._get_proxy_manager().get_session_client(session_id, env_id, session_dir, cwd) + return self._get_proxy_manager().get_session_client( + session_id, + env_id, + session_dir, + cwd, + ) - return DeferredClient(client_factory) + return DeferredClient(client_factory, self.async_warm_up) def get_kernel_mode(self) -> KernelModeType: return self.kernel_mode def _get_proxy_manager(self) -> Manager: - if self.proxy_manager is None: - self.proxy_manager = self.manager_factory() - self.proxy_manager.initialize() - return self.proxy_manager + return self._init_deferred_var()() + + def _init_deferred_var(self) -> Callable[[], Manager]: + if self.deferred_var is None: + self.deferred_var = deferred_var( + "DeferredManager", + self.manager_factory, + self.async_warm_up, + ) + return self.deferred_var diff --git a/taskweaver/code_interpreter/code_executor.py b/taskweaver/code_interpreter/code_executor.py index ad8b696f..40974950 100644 --- a/taskweaver/code_interpreter/code_executor.py +++ b/taskweaver/code_interpreter/code_executor.py @@ -67,16 +67,14 @@ def __init__( @tracing_decorator def execute_code(self, exec_id: str, code: str) -> ExecutionResult: - if not self.client_started: - with get_tracer().start_as_current_span("start"): - self.start() - self.client_started = True + with get_tracer().start_as_current_span("start"): + self.start() if not self.plugin_loaded: with get_tracer().start_as_current_span("load_plugin"): self.load_plugin() self.plugin_loaded = True - + # update session variables self.exec_client.update_session_var(self.session_variables) @@ -111,7 +109,7 @@ def execute_code(self, exec_id: str, code: str) -> ExecutionResult: ) return result - + def update_session_var(self, session_var_dict: dict) -> None: self.session_variables.update(session_var_dict) @@ -146,7 +144,9 @@ def load_plugin(self): print(f"Plugin {p.name} failed to load: {str(e)}") def start(self): - self.exec_client.start() + if not self.client_started: + self.exec_client.start() + self.client_started = True def stop(self): self.exec_client.stop() diff --git a/taskweaver/code_interpreter/code_interpreter/code_interpreter.py b/taskweaver/code_interpreter/code_interpreter/code_interpreter.py index e4d20e25..82dc9da6 100644 --- a/taskweaver/code_interpreter/code_interpreter/code_interpreter.py +++ b/taskweaver/code_interpreter/code_interpreter/code_interpreter.py @@ -131,6 +131,7 @@ def reply( ) -> Post: post_proxy = self.event_emitter.create_post_proxy(self.alias) post_proxy.update_status("generating code") + self.executor.start() self.generator.reply( memory, post_proxy, diff --git a/taskweaver/code_interpreter/code_interpreter_cli_only/code_interpreter_cli_only.py b/taskweaver/code_interpreter/code_interpreter_cli_only/code_interpreter_cli_only.py index 48330949..910b48ec 100644 --- a/taskweaver/code_interpreter/code_interpreter_cli_only/code_interpreter_cli_only.py +++ b/taskweaver/code_interpreter/code_interpreter_cli_only/code_interpreter_cli_only.py @@ -53,6 +53,7 @@ def reply( **kwargs: ..., ) -> Post: post_proxy = self.event_emitter.create_post_proxy(self.alias) + self.executor.start() self.generator.reply( memory, post_proxy=post_proxy, diff --git a/taskweaver/code_interpreter/code_interpreter_plugin_only/code_interpreter_plugin_only.py b/taskweaver/code_interpreter/code_interpreter_plugin_only/code_interpreter_plugin_only.py index bf03dc6e..b1600be8 100644 --- a/taskweaver/code_interpreter/code_interpreter_plugin_only/code_interpreter_plugin_only.py +++ b/taskweaver/code_interpreter/code_interpreter_plugin_only/code_interpreter_plugin_only.py @@ -67,6 +67,7 @@ def reply( **kwargs: ..., ) -> Post: post_proxy = self.event_emitter.create_post_proxy(self.alias) + self.executor.start() self.generator.reply( memory, post_proxy=post_proxy, diff --git a/taskweaver/utils/time_usage.py b/taskweaver/utils/time_usage.py new file mode 100644 index 00000000..d2e5d65c --- /dev/null +++ b/taskweaver/utils/time_usage.py @@ -0,0 +1,24 @@ +import time +from contextlib import contextmanager +from dataclasses import dataclass + + +@dataclass() +class TimeUsage: + start: float + end: float + process: float + total: float + + +@contextmanager +def time_usage(): + usage = TimeUsage(start=time.time(), end=0, process=0, total=0) + perf_time_start = time.perf_counter_ns() + process_start = time.process_time_ns() + yield usage + process_end = time.process_time_ns() + perf_time_end = time.perf_counter_ns() + usage.end = time.time() + usage.process = round((process_end - process_start) / 1e6, 3) + usage.total = round((perf_time_end - perf_time_start) / 1e6, 3)