From 1d2b229bce10972c371bb7b89f34d234f851669b Mon Sep 17 00:00:00 2001 From: liqun Date: Fri, 6 Sep 2024 17:51:16 +0800 Subject: [PATCH] init --- .../code_interpreter/code_generator.py | 37 +++++++++--- taskweaver/ext_role/echo/echo.py | 6 ++ taskweaver/memory/attachment.py | 4 ++ taskweaver/memory/experience.py | 58 +++++++++++++------ taskweaver/planner/planner.py | 41 ++++++++++--- taskweaver/session/session.py | 40 +++++++++++-- tests/unit_tests/test_experience.py | 2 +- 7 files changed, 147 insertions(+), 41 deletions(-) diff --git a/taskweaver/code_interpreter/code_interpreter/code_generator.py b/taskweaver/code_interpreter/code_interpreter/code_generator.py index f72e1ecd..7011c4f3 100644 --- a/taskweaver/code_interpreter/code_interpreter/code_generator.py +++ b/taskweaver/code_interpreter/code_interpreter/code_generator.py @@ -56,6 +56,7 @@ def _configure(self) -> None: self.auto_plugin_selection_topk = self._get_int("auto_plugin_selection_topk", 3) self.use_experience = self._get_bool("use_experience", False) + self.dynamic_experience_filter = self._get_bool("dynamic_experience_filter", False) self.llm_alias = self._get_str("llm_alias", default="", required=False) @@ -104,17 +105,31 @@ def __init__( logger.info("Plugin embeddings loaded") self.selected_plugin_pool = SelectedPluginPool() - if self.config.use_experience: - self.experience_generator = experience_generator - self.experience_generator.refresh() - self.experience_generator.load_experience() - self.logger.info( - "Experience loaded successfully, " - "there are {} experiences".format(len(self.experience_generator.experience_list)), - ) + self.experience_generator = experience_generator + self.exp_loaded = False + if self.config.dynamic_experience_filter: + self.exp_filter_str = None + else: + # use the experience folder + self.exp_filter_str = "" self.logger.info("CodeGenerator initialized successfully") + def load_experience(self): + if self.exp_filter_str is None or self.exp_loaded: + return + self.experience_generator.set_sub_dir(self.exp_filter_str) + self.experience_generator.refresh() + self.experience_generator.load_experience() + self.logger.info( + "Experience loaded successfully, " + "there are {} experiences with filter [{}]".format( + len(self.experience_generator.experience_list), + self.exp_filter_str, + ) + ) + self.exp_loaded = True + def configure_verification( self, code_verification_on: bool, @@ -363,6 +378,11 @@ def reply( # obtain the query from the last round query = rounds[-1].post_list[-1].message + exp_filter = rounds[-1].post_list[-1].get_attachment(AttachmentType.exp_filter) + if exp_filter: + self.exp_filter_str = exp_filter[0].content + self.tracing.set_span_attribute("exp_filter", self.exp_filter_str) + self.tracing.set_span_attribute("query", query) self.tracing.set_span_attribute("enable_auto_plugin_selection", self.config.enable_auto_plugin_selection) self.tracing.set_span_attribute("use_experience", self.config.use_experience) @@ -371,6 +391,7 @@ def reply( self.plugin_pool = self.select_plugins_for_prompt(query) if self.config.use_experience: + self.load_experience() selected_experiences = self.experience_generator.retrieve_experience(query) else: selected_experiences = None diff --git a/taskweaver/ext_role/echo/echo.py b/taskweaver/ext_role/echo/echo.py index c7951a15..d82d02e1 100644 --- a/taskweaver/ext_role/echo/echo.py +++ b/taskweaver/ext_role/echo/echo.py @@ -2,6 +2,7 @@ from taskweaver.logging import TelemetryLogger from taskweaver.memory import Memory, Post +from taskweaver.memory.attachment import AttachmentType from taskweaver.module.event_emitter import SessionEventEmitter from taskweaver.module.tracing import Tracing from taskweaver.role import Role @@ -41,4 +42,9 @@ def reply(self, memory: Memory, **kwargs: ...) -> Post: self.config.decorator + last_post.message + self.config.decorator, ) + post_proxy.update_attachment( + type=AttachmentType.signal, + message="exp_filter:sub_exp", + ) + return post_proxy.end() diff --git a/taskweaver/memory/attachment.py b/taskweaver/memory/attachment.py index 183a5b70..4c298930 100644 --- a/taskweaver/memory/attachment.py +++ b/taskweaver/memory/attachment.py @@ -45,6 +45,10 @@ class AttachmentType(Enum): # board info board = "board" + # signal + signal = "signal" + exp_filter = "exp_filter" + @dataclass class Attachment: diff --git a/taskweaver/memory/experience.py b/taskweaver/memory/experience.py index b61a0954..70044175 100644 --- a/taskweaver/memory/experience.py +++ b/taskweaver/memory/experience.py @@ -1,7 +1,8 @@ import json import os +from collections import defaultdict from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Set from injector import inject @@ -83,6 +84,11 @@ def __init__( "run `python -m experience_mgt --refresh` to refresh the experience." ) + self.sub_dir = None + + def set_sub_dir(self, sub_dir: str): + self.sub_dir = sub_dir + @staticmethod def _preprocess_conversation_data( conv_data: dict, @@ -109,7 +115,9 @@ def summarize_experience( exp_id: str, prompt: Optional[str] = None, ): - raw_exp_file_path = os.path.join(self.config.experience_dir, f"raw_exp_{exp_id}.yaml") + exp_dir = self.get_experience_dir() + + raw_exp_file_path = os.path.join(exp_dir, f"raw_exp_{exp_id}.yaml") conversation = read_yaml(raw_exp_file_path) conversation = self._preprocess_conversation_data(conversation) @@ -145,10 +153,12 @@ def refresh( self, prompt: Optional[str] = None, ): - if not os.path.exists(self.config.experience_dir): - raise ValueError(f"Experience directory {self.config.experience_dir} does not exist.") + exp_dir = self.get_experience_dir() - exp_files = os.listdir(self.config.experience_dir) + if not os.path.exists(exp_dir): + raise ValueError(f"Experience directory {exp_dir} does not exist.") + + exp_files = os.listdir(exp_dir) raw_exp_ids = [ os.path.splitext(os.path.basename(exp_file))[0].split("_")[2] @@ -176,10 +186,10 @@ def refresh( for idx, exp_id in enumerate(exp_ids): rebuild_flag = False exp_file_name = f"exp_{exp_id}.yaml" - if exp_file_name not in os.listdir(self.config.experience_dir): + if exp_file_name not in os.listdir(exp_dir): rebuild_flag = True else: - exp_file_path = os.path.join(self.config.experience_dir, exp_file_name) + exp_file_path = os.path.join(exp_dir, exp_file_name) experience = read_yaml(exp_file_path) if ( experience["embedding_model"] != self.llm_api.embedding_service.config.embedding_model @@ -194,13 +204,13 @@ def refresh( experience_text=summarized_experience, exp_id=exp_id, raw_experience_path=os.path.join( - self.config.experience_dir, + exp_dir, f"raw_exp_{exp_id}.yaml", ), ) elif exp_id in handcrafted_exp_ids: handcrafted_exp_file_path = os.path.join( - self.config.experience_dir, + exp_dir, f"handcrafted_exp_{exp_id}.yaml", ) experience_obj = Experience.from_dict(read_yaml(handcrafted_exp_file_path)) @@ -218,21 +228,21 @@ def refresh( for i, exp in enumerate(to_be_embedded): exp.embedding = exp_embeddings[i] exp.embedding_model = self.llm_api.embedding_service.config.embedding_model - experience_file_path = os.path.join(self.config.experience_dir, f"exp_{exp.exp_id}.yaml") + experience_file_path = os.path.join(exp_dir, f"exp_{exp.exp_id}.yaml") write_yaml(experience_file_path, exp.to_dict()) self.logger.info("Experience obj saved.") @tracing_decorator - def load_experience( - self, - ): - if not os.path.exists(self.config.experience_dir): - raise ValueError(f"Experience directory {self.config.experience_dir} does not exist.") + def load_experience(self): + exp_dir = self.get_experience_dir() + + if not os.path.exists(exp_dir): + raise ValueError(f"Experience directory {exp_dir} does not exist.") original_exp_files = [ exp_file - for exp_file in os.listdir(self.config.experience_dir) + for exp_file in os.listdir(exp_dir) if exp_file.startswith("raw_exp_") or exp_file.startswith("handcrafted_exp_") ] exp_ids = [os.path.splitext(os.path.basename(exp_file))[0].split("_")[2] for exp_file in original_exp_files] @@ -245,8 +255,12 @@ def load_experience( return for exp_id in exp_ids: + exp_id_exists = exp_id in [exp.exp_id for exp in self.experience_list] + if exp_id_exists: + continue + exp_file = f"exp_{exp_id}.yaml" - exp_file_path = os.path.join(self.config.experience_dir, exp_file) + exp_file_path = os.path.join(exp_dir, exp_file) assert os.path.exists(exp_file_path), ( f"Experience {exp_file} not found. " + self.exception_message_for_refresh ) @@ -293,12 +307,18 @@ def retrieve_experience(self, user_query: str) -> List[Tuple[Experience, float]] return selected_experiences def _delete_exp_file(self, exp_file_name: str): - if exp_file_name in os.listdir(self.config.experience_dir): - os.remove(os.path.join(self.config.experience_dir, exp_file_name)) + exp_dir = self.get_experience_dir() + + if exp_file_name in os.listdir(exp_dir): + os.remove(os.path.join(exp_dir, exp_file_name)) self.logger.info(f"Experience {exp_file_name} deleted.") else: self.logger.info(f"Experience {exp_file_name} not found.") + def get_experience_dir(self): + return os.path.join(self.config.experience_dir, self.sub_dir) \ + if self.sub_dir else self.config.experience_dir + def delete_experience(self, exp_id: str): exp_file_name = f"exp_{exp_id}.yaml" self._delete_exp_file(exp_file_name) diff --git a/taskweaver/planner/planner.py b/taskweaver/planner/planner.py index 6c5241e1..fab5f167 100644 --- a/taskweaver/planner/planner.py +++ b/taskweaver/planner/planner.py @@ -50,6 +50,7 @@ def _configure(self) -> None: ) self.use_experience = self._get_bool("use_experience", False) + self.dynamic_experience_filter = self._get_bool("dynamic_experience_filter", False) self.llm_alias = self._get_str("llm_alias", default="", required=False) @@ -100,18 +101,32 @@ def __init__( self.round_compressor = round_compressor self.compression_prompt_template = read_yaml(self.config.compression_prompt_path)["content"] - if self.config.use_experience: - assert experience_generator is not None, "Experience generator is required when use_experience is True" - self.experience_generator = experience_generator - self.experience_generator.refresh() - self.experience_generator.load_experience() - self.logger.info( - "Experience loaded successfully, " - "there are {} experiences".format(len(self.experience_generator.experience_list)), - ) + self.experience_generator = experience_generator + self.exp_loaded = False + if self.config.dynamic_experience_filter: + self.exp_filter_str = None + else: + # use the experience folder + self.exp_filter_str = "" self.logger.info("Planner initialized successfully") + def load_experience(self): + if self.exp_filter_str is None or self.exp_loaded: + return + self.experience_generator.set_sub_dir(self.exp_filter_str) + self.experience_generator.refresh() + self.experience_generator.load_experience() + self.logger.info( + "Experience loaded successfully, " + "there are {} experiences with filter [{}]".format( + len(self.experience_generator.experience_list), + self.exp_filter_str, + ) + ) + + self.exp_loaded = True + def compose_sys_prompt(self, context: str): worker_description = "" for alias, role in self.workers.items(): @@ -274,10 +289,18 @@ def reply( assert len(rounds) != 0, "No chat rounds found for planner" user_query = rounds[-1].user_query + + exp_filter = rounds[-1].post_list[-1].get_attachment(AttachmentType.exp_filter) + if exp_filter: + self.exp_filter_str = exp_filter[0] + self.tracing.set_span_attribute("exp_filter", self.exp_filter_str) + self.tracing.set_span_attribute("user_query", user_query) self.tracing.set_span_attribute("use_experience", self.config.use_experience) + self.tracing.set_span_attribute("exp_filter", self.exp_filter_str) if self.config.use_experience: + self.load_experience() selected_experiences = self.experience_generator.retrieve_experience(user_query) else: selected_experiences = None diff --git a/taskweaver/session/session.py b/taskweaver/session/session.py index 0919eeb2..a94a55f4 100644 --- a/taskweaver/session/session.py +++ b/taskweaver/session/session.py @@ -4,10 +4,11 @@ from typing import Any, Dict, List, Literal, Optional from injector import Injector, inject +from scipy.stats import logser from taskweaver.config.module_config import ModuleConfig from taskweaver.logging import TelemetryLogger -from taskweaver.memory import Memory, Post, Round +from taskweaver.memory import Memory, Post, Round, Attachment from taskweaver.memory.attachment import AttachmentType from taskweaver.module.event_emitter import SessionEventEmitter, SessionEventHandler from taskweaver.module.tracing import Tracing, tracing_decorator, tracing_decorator_non_class @@ -88,6 +89,7 @@ def __init__( self.memory = Memory(session_id=self.session_id) self.session_var: Dict[str, str] = {} + self.session_signal: Dict[AttachmentType, str] = {} self.event_emitter = self.session_injector.get(SessionEventEmitter) self.session_injector.binder.bind(SessionEventEmitter, self.event_emitter) @@ -100,11 +102,21 @@ def __init__( if role_name not in role_registry.get_role_name_list(): raise ValueError(f"Unknown role {role_name}") role_entry = self.role_registry.get(role_name) - role_instance = self.session_injector.create_object(role_entry.module, {"role_entry": role_entry}) + role_instance = self.session_injector.create_object( + role_entry.module, + { + "role_entry": role_entry + } + ) self.worker_instances[role_instance.get_alias()] = role_instance if "planner" in self.config.roles: - self.planner = self.session_injector.create_object(Planner, {"workers": self.worker_instances}) + self.planner = self.session_injector.create_object( + Planner, + { + "workers": self.worker_instances + } + ) self.session_injector.binder.bind(Planner, self.planner) self.max_internal_chat_round_num = self.config.max_internal_chat_round_num @@ -163,6 +175,16 @@ def _send_text_message( @tracing_decorator_non_class def _send_message(recipient: str, post: Post) -> Post: + # add session signal to the post + if self.session_signal: + for signal_type in self.session_signal: + post.add_attachment( + Attachment.create( + type=signal_type, + content=self.session_signal[signal_type], + ) + ) + self.tracing.set_span_attribute("in.from", post.send_from) self.tracing.set_span_attribute("in.recipient", recipient) self.tracing.set_span_attribute("in.message", post.message) @@ -191,7 +213,17 @@ def _send_message(recipient: str, post: Post) -> Post: board_attachment = reply_post.get_attachment(AttachmentType.board) if len(board_attachment) > 0: - chat_round.write_board(reply_post.send_from, reply_post.get_attachment(AttachmentType.board)[0]) + chat_round.write_board( + reply_post.send_from, + reply_post.get_attachment(AttachmentType.board)[0] + ) + + signal_attachments = reply_post.get_attachment(AttachmentType.signal) + for signal in signal_attachments: + signal_type, signal_value = signal.split(":") + self.logger.info(f"Session signal: {signal_type}={signal_value}") + # signal_type must be in AttachmentType + self.session_signal[AttachmentType(signal_type)] = signal_value return reply_post diff --git a/tests/unit_tests/test_experience.py b/tests/unit_tests/test_experience.py index 46a0d6d2..1db44ed1 100644 --- a/tests/unit_tests/test_experience.py +++ b/tests/unit_tests/test_experience.py @@ -10,7 +10,7 @@ IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" -@pytest.mark.skipif(True, reason="Test doesn't work in Github Actions.") +@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Test doesn't work in Github Actions.") def test_experience_retrieval(): app_injector = Injector([LoggingModule]) app_config = AppConfigSource(