Skip to content

Commit

Permalink
Improve console chat launch speed by defer unnecessary package loading (
Browse files Browse the repository at this point in the history
  • Loading branch information
Jack-Q authored Aug 21, 2024
1 parent 69089c7 commit 18e8f8b
Show file tree
Hide file tree
Showing 41 changed files with 554 additions and 213 deletions.
5 changes: 5 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[pytest]
markers =
app_config: mark a test that requires the app config
testpaths =
tests
33 changes: 3 additions & 30 deletions taskweaver/app/app.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from os import listdir, path
from os import path
from typing import Any, Dict, Optional, Tuple

from injector import Injector
Expand All @@ -9,8 +9,6 @@
from taskweaver.memory.plugin import PluginModule
from taskweaver.module.execution_service import ExecutionServiceModule
from taskweaver.role.role import RoleModule

# if TYPE_CHECKING:
from taskweaver.session.session import Session


Expand Down Expand Up @@ -77,34 +75,9 @@ def discover_app_dir(
"""
Discover the app directory from the given path or the current working directory.
"""
from taskweaver.utils.app_utils import discover_app_dir

def validate_app_config(workspace: str) -> bool:
config_path = path.join(workspace, "taskweaver_config.json")
if not path.exists(config_path):
return False
# TODO: read, parse and validate config
return True

def is_dir_valid(dir: str) -> bool:
return path.exists(dir) and path.isdir(dir) and validate_app_config(dir)

def is_empty(dir: str) -> bool:
return not path.exists(dir) or (path.isdir(dir) and len(listdir(dir)) == 0)

if app_dir is not None:
app_dir = path.abspath(app_dir)
return app_dir, is_dir_valid(app_dir), is_empty(app_dir)
else:
cwd = path.abspath(".")
cur_dir = cwd
while True:
if is_dir_valid(cur_dir):
return cur_dir, True, False

next_path = path.abspath(path.join(cur_dir, ".."))
if next_path == cur_dir:
return cwd, False, is_empty(cwd)
cur_dir = next_path
return discover_app_dir(app_dir)

def _init_app_modules(self) -> None:
from taskweaver.llm import LLMApi
Expand Down
12 changes: 9 additions & 3 deletions taskweaver/ces/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
from typing import Literal

from taskweaver.ces.common import Manager
from taskweaver.ces.environment import Environment, EnvMode
from taskweaver.ces.manager.defer import DeferredManager
from taskweaver.ces.manager.sub_proc import SubProcessManager


def code_execution_service_factory(
env_dir: str,
kernel_mode: Literal["local", "container"] = "local",
) -> Manager:
return SubProcessManager(
env_dir=env_dir,
def sub_proc_manager_factory() -> SubProcessManager:
return SubProcessManager(
env_dir=env_dir,
kernel_mode=kernel_mode,
)

return DeferredManager(
kernel_mode=kernel_mode,
manager_factory=sub_proc_manager_factory,
)
5 changes: 4 additions & 1 deletion taskweaver/ces/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ def execute_code(self, exec_id: str, code: str) -> ExecutionResult:
...


KernelModeType = Literal["local", "container"]


class Manager(ABC):
"""
Manager is the interface for the execution manager.
Expand All @@ -128,5 +131,5 @@ def get_session_client(
...

@abstractmethod
def get_kernel_mode(self) -> Literal["local", "container"] | None:
def get_kernel_mode(self) -> KernelModeType:
...
74 changes: 74 additions & 0 deletions taskweaver/ces/manager/defer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from __future__ import annotations

from typing import Callable, Dict, Optional

from taskweaver.ces.common import Client, ExecutionResult, KernelModeType, Manager


class DeferredClient(Client):
def __init__(self, client_factory: Callable[[], Client]) -> None:
self.client_factory = client_factory
self.proxy_client: Optional[Client] = None

def start(self) -> None:
# defer the start to the proxy client
pass

def stop(self) -> None:
if self.proxy_client is not None:
self.proxy_client.stop()

def load_plugin(self, plugin_name: str, plugin_code: str, plugin_config: Dict[str, str]) -> None:
self._get_proxy_client().load_plugin(plugin_name, plugin_code, plugin_config)

def test_plugin(self, plugin_name: str) -> None:
self._get_proxy_client().test_plugin(plugin_name)

def update_session_var(self, session_var_dict: Dict[str, str]) -> None:
self._get_proxy_client().update_session_var(session_var_dict)

def execute_code(self, exec_id: str, code: str) -> ExecutionResult:
return self._get_proxy_client().execute_code(exec_id, code)

def _get_proxy_client(self) -> Client:
if self.proxy_client is None:
self.proxy_client = self.client_factory()
self.proxy_client.start()
return self.proxy_client


class DeferredManager(Manager):
def __init__(self, kernel_mode: KernelModeType, manager_factory: Callable[[], Manager]) -> None:
super().__init__()
self.kernel_mode: KernelModeType = kernel_mode
self.manager_factory = manager_factory
self.proxy_manager: Optional[Manager] = None

def initialize(self) -> None:
# defer the initialization to the proxy manager
pass

def clean_up(self) -> None:
if self.proxy_manager is not None:
self.proxy_manager.clean_up()

def get_session_client(
self,
session_id: str,
env_id: Optional[str] = None,
session_dir: Optional[str] = None,
cwd: Optional[str] = None,
) -> DeferredClient:
def client_factory() -> Client:
return self._get_proxy_manager().get_session_client(session_id, env_id, session_dir, cwd)

return DeferredClient(client_factory)

def get_kernel_mode(self) -> KernelModeType:
return self.kernel_mode

def _get_proxy_manager(self) -> Manager:
if self.proxy_manager is None:
self.proxy_manager = self.manager_factory()
self.proxy_manager.initialize()
return self.proxy_manager
13 changes: 7 additions & 6 deletions taskweaver/ces/manager/sub_proc.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from __future__ import annotations

import os
from typing import Dict, Literal, Optional
from typing import Dict, Optional

from taskweaver.ces import Environment, EnvMode
from taskweaver.ces.common import Client, ExecutionResult, Manager
from taskweaver.ces.common import Client, ExecutionResult, KernelModeType, Manager


class SubProcessClient(Client):
Expand Down Expand Up @@ -57,14 +56,16 @@ def __init__(
self,
env_id: Optional[str] = None,
env_dir: Optional[str] = None,
kernel_mode: Optional[Literal["local", "container"]] = "local",
kernel_mode: KernelModeType = "local",
) -> None:
from taskweaver.ces.environment import Environment, EnvMode

env_id = env_id or os.getenv("TASKWEAVER_ENV_ID", "local")
env_dir = env_dir or os.getenv(
"TASKWEAVER_ENV_DIR",
os.path.realpath(os.getcwd()),
)
self.kernel_mode = kernel_mode
self.kernel_mode: KernelModeType = kernel_mode
if self.kernel_mode == "local":
env_mode = EnvMode.Local
elif self.kernel_mode == "container":
Expand Down Expand Up @@ -102,5 +103,5 @@ def get_session_client(
cwd=cwd,
)

def get_kernel_mode(self) -> Literal["local", "container"] | None:
def get_kernel_mode(self) -> KernelModeType:
return self.kernel_mode
18 changes: 12 additions & 6 deletions taskweaver/chat/console/chat.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
from __future__ import annotations

import atexit
import shutil
import threading
import time
from textwrap import TextWrapper, dedent
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple

import click
from colorama import ansi

from taskweaver.app.app import TaskWeaverApp
from taskweaver.memory.attachment import AttachmentType
from taskweaver.module.event_emitter import PostEventType, RoundEventType, SessionEventHandlerBase, SessionEventType
from taskweaver.session.session import Session

if TYPE_CHECKING:
from taskweaver.memory.attachment import AttachmentType
from taskweaver.session.session import Session


def error_message(message: str) -> None:
Expand Down Expand Up @@ -272,6 +274,8 @@ def wrap_message(
return "\n".join(result)

def clear_line():
from colorama import ansi

print(ansi.clear_line(), end="\r")

def get_ani_frame(frame: int = 0):
Expand Down Expand Up @@ -405,6 +409,8 @@ def format_status_message(limit: int):

class TaskWeaverChatApp(SessionEventHandlerBase):
def __init__(self, app_dir: Optional[str] = None):
from taskweaver.app.app import TaskWeaverApp

self.app = TaskWeaverApp(app_dir=app_dir, use_local_uri=True)
self.session = self.app.get_session()
self.pending_files: List[Dict[Literal["name", "path", "content"], Any]] = []
Expand Down Expand Up @@ -436,7 +442,7 @@ def _process_user_input(self, user_input: str) -> None:
if lower_command == "reset":
self._reset_session()
return
if lower_command in ["load", "file"]:
if lower_command in ["load", "file", "img", "image"]:
file_to_load = msg[5:].strip()
self._load_file(file_to_load)
return
Expand Down
10 changes: 5 additions & 5 deletions taskweaver/cli/cli.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
import click

from ..app import TaskWeaverApp
from .chat import chat
from .init import init
from .util import CliContext, get_ascii_banner
from .web import web


@click.group(
name="taskweaver",
help=f"\b\n{get_ascii_banner()}\nTaskWeaver",
help=f"\b\n{get_ascii_banner(center=False)}\nTaskWeaver",
invoke_without_command=True,
commands=[init, chat, web],
commands=[init, chat],
)
@click.pass_context
@click.version_option(package_name="taskweaver")
Expand All @@ -28,7 +26,9 @@
default=None,
)
def taskweaver(ctx: click.Context, project: str):
workspace_base, is_valid, is_empty = TaskWeaverApp.discover_app_dir(project)
from taskweaver.utils.app_utils import discover_app_dir

workspace_base, is_valid, is_empty = discover_app_dir(project)

# subcommand_target = ctx.invoked_subcommand if ctx.invoked_subcommand is not None else "chat"

Expand Down
6 changes: 5 additions & 1 deletion taskweaver/cli/init.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
import shutil
from typing import Any

import click
Expand Down Expand Up @@ -70,6 +69,8 @@ def init(
zip_ref.extractall(tpl_dir)
copy_files(os.path.join(tpl_dir, "project"), project)
try:
import shutil

shutil.rmtree(tpl_dir)
except Exception:
click.secho("Failed to remove temporary directory", fg="yellow")
Expand All @@ -83,5 +84,8 @@ def copy_files(src_dir: str, dst_dir: str):
# Check if the destination folder exists. If not, create it.
if not os.path.exists(dst_dir):
os.makedirs(dst_dir)

import shutil

# Copy the content of source_folder to destination_folder
shutil.copytree(src_dir, dst_dir, dirs_exist_ok=True)
17 changes: 15 additions & 2 deletions taskweaver/cli/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,17 @@ class CliContext:
is_workspace_empty: bool


def get_ascii_banner() -> str:
return dedent(
def center_cli_str(text: str, width: Optional[int] = None):
import shutil

width = width or shutil.get_terminal_size().columns
lines = text.split("\n")
max_line_len = max(len(line) for line in lines)
return "\n".join((line + " " * (max_line_len - len(line))).center(width) for line in lines)


def get_ascii_banner(center: bool = True) -> str:
text = dedent(
r"""
=========================================================
_____ _ _ __
Expand All @@ -47,3 +56,7 @@ def get_ascii_banner() -> str:
=========================================================
""",
).strip()
if center:
return center_cli_str(text)
else:
return text
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def __init__(
experience_generator: ExperienceGenerator,
):
super().__init__(config, logger, tracing, event_emitter)
self.config = config
self.llm_api = llm_api

self.role_name = self.config.role_name
Expand Down Expand Up @@ -262,10 +263,7 @@ def compose_conversation(
supplementary_info_dict = conversation_round.read_board()
supplementary_info = "\n\n".join([bulletin for bulletin in supplementary_info_dict.values()])
if supplementary_info != "":
enrichment += (
f"Additional context:\n"
f" {supplementary_info}\n\n"
)
enrichment += f"Additional context:\n" f" {supplementary_info}\n\n"

user_feedback = "None"
if last_post is not None and last_post.send_from == self.alias:
Expand Down Expand Up @@ -352,6 +350,7 @@ def reply(
memory: Memory,
post_proxy: Optional[PostEventProxy] = None,
prompt_log_path: Optional[str] = None,
**kwargs: ...,
) -> Post:
assert post_proxy is not None, "Post proxy is not provided."

Expand Down Expand Up @@ -425,7 +424,7 @@ def early_stop(_type: AttachmentType, value: str) -> bool:
self.selected_plugin_pool.filter_unused_plugins(code=generated_code)

if prompt_log_path is not None:
self.logger.dump_log_file(prompt, prompt_log_path)
self.logger.dump_prompt_file(prompt, prompt_log_path)

self.tracing.set_span_attribute("code", generated_code)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def reply(
self,
memory: Memory,
prompt_log_path: Optional[str] = None,
**kwargs: ...,
) -> Post:
post_proxy = self.event_emitter.create_post_proxy(self.alias)
post_proxy.update_status("generating code")
Expand Down
Loading

0 comments on commit 18e8f8b

Please sign in to comment.