From 2d8086d6a678eb190f88b77ce3504ad4b3126af8 Mon Sep 17 00:00:00 2001 From: Liqun Li Date: Thu, 26 Sep 2024 18:20:54 +0800 Subject: [PATCH] Dynamic loading examples (#420) 1. support dynamically loading examples for different roles 2. move loading example to role base class 3. rename `codeinterpreter_examples` to `code_generator_examples` in project folder; need manual migration --- .../example1-codeinterpreter.yaml | 0 .../example2-codeinterpreter.yaml | 0 .../planner_examples/example-planner-2.yaml | 0 .../example-planner-echo.yaml | 0 .../planner_examples/example-planner.yaml | 0 .../code_interpreter/code_generator.py | 41 +----- taskweaver/memory/experience.py | 20 +-- taskweaver/memory/type_vars.py | 2 +- taskweaver/misc/example.py | 7 + taskweaver/planner/planner.py | 51 ++----- taskweaver/role/role.py | 139 +++++++++++++----- .../planner_examples/sub/example-planner.yaml | 43 ++++++ tests/unit_tests/test_code_generator.py | 22 ++- tests/unit_tests/test_example.py | 40 ++++- tests/unit_tests/test_planner.py | 2 + tests/unit_tests/test_role.py | 34 ++++- website/blog/experience.md | 4 +- website/docs/concepts/project.md | 5 +- website/docs/quickstart.md | 5 +- 19 files changed, 263 insertions(+), 152 deletions(-) rename project/{codeinterpreter_examples => examples/code_generator_examples}/example1-codeinterpreter.yaml (100%) rename project/{codeinterpreter_examples => examples/code_generator_examples}/example2-codeinterpreter.yaml (100%) rename project/{ => examples}/planner_examples/example-planner-2.yaml (100%) rename project/{ => examples}/planner_examples/example-planner-echo.yaml (100%) rename project/{ => examples}/planner_examples/example-planner.yaml (100%) create mode 100644 tests/unit_tests/data/examples/planner_examples/sub/example-planner.yaml diff --git a/project/codeinterpreter_examples/example1-codeinterpreter.yaml b/project/examples/code_generator_examples/example1-codeinterpreter.yaml similarity index 100% rename from project/codeinterpreter_examples/example1-codeinterpreter.yaml rename to project/examples/code_generator_examples/example1-codeinterpreter.yaml diff --git a/project/codeinterpreter_examples/example2-codeinterpreter.yaml b/project/examples/code_generator_examples/example2-codeinterpreter.yaml similarity index 100% rename from project/codeinterpreter_examples/example2-codeinterpreter.yaml rename to project/examples/code_generator_examples/example2-codeinterpreter.yaml diff --git a/project/planner_examples/example-planner-2.yaml b/project/examples/planner_examples/example-planner-2.yaml similarity index 100% rename from project/planner_examples/example-planner-2.yaml rename to project/examples/planner_examples/example-planner-2.yaml diff --git a/project/planner_examples/example-planner-echo.yaml b/project/examples/planner_examples/example-planner-echo.yaml similarity index 100% rename from project/planner_examples/example-planner-echo.yaml rename to project/examples/planner_examples/example-planner-echo.yaml diff --git a/project/planner_examples/example-planner.yaml b/project/examples/planner_examples/example-planner.yaml similarity index 100% rename from project/planner_examples/example-planner.yaml rename to project/examples/planner_examples/example-planner.yaml diff --git a/taskweaver/code_interpreter/code_interpreter/code_generator.py b/taskweaver/code_interpreter/code_interpreter/code_generator.py index b456b589..795b2f15 100644 --- a/taskweaver/code_interpreter/code_interpreter/code_generator.py +++ b/taskweaver/code_interpreter/code_interpreter/code_generator.py @@ -1,7 +1,7 @@ import datetime import json import os -from typing import List, Optional, Tuple +from typing import List, Optional from injector import inject @@ -9,11 +9,10 @@ from taskweaver.llm import LLMApi from taskweaver.llm.util import ChatMessageType, format_chat_message from taskweaver.logging import TelemetryLogger -from taskweaver.memory import Attachment, Conversation, Memory, Post, Round, RoundCompressor +from taskweaver.memory import Attachment, Memory, Post, Round, RoundCompressor from taskweaver.memory.attachment import AttachmentType -from taskweaver.memory.experience import Experience, ExperienceGenerator +from taskweaver.memory.experience import ExperienceGenerator from taskweaver.memory.plugin import PluginEntry, PluginRegistry -from taskweaver.misc.example import load_examples from taskweaver.module.event_emitter import PostEventProxy, SessionEventEmitter from taskweaver.module.tracing import Tracing, tracing_decorator from taskweaver.role import PostTranslator, Role @@ -26,7 +25,6 @@ def _configure(self) -> None: self._set_name("code_generator") self.role_name = self._get_str("role_name", "ProgramApe") self.load_plugin = self._get_bool("load_plugin", True) - self.load_example = self._get_bool("load_example", True) self.prompt_file_path = self._get_path( "prompt_file_path", os.path.join( @@ -34,13 +32,6 @@ def _configure(self) -> None: "code_generator_prompt.yaml", ), ) - self.example_base_path = self._get_path( - "example_base_path", - os.path.join( - self.src.app_base_path, - "codeinterpreter_examples", - ), - ) self.prompt_compression = self._get_bool("prompt_compression", False) self.compression_prompt_path = self._get_path( "compression_prompt_path", @@ -89,7 +80,6 @@ def __init__( self.query_requirements_template = self.prompt_data["requirements"] self.response_json_schema = json.loads(self.prompt_data["response_json_schema"]) - self.examples = None self.code_verification_on: bool = False self.allowed_modules: List[str] = [] @@ -157,12 +147,10 @@ def compose_prompt( self, rounds: List[Round], plugins: List[PluginEntry], - selected_experiences: Optional[List[Tuple[Experience, float]]] = None, planning_enrichments: Optional[List[str]] = None, ) -> List[ChatMessageType]: experiences = self.format_experience( template=self.prompt_data["experience_instruction"], - experiences=selected_experiences, ) chat_history = [ @@ -172,8 +160,6 @@ def compose_prompt( ), ] - if self.examples is None: - self.examples = self.load_examples() for i, example in enumerate(self.examples): chat_history.extend( self.compose_conversation(example.rounds, example.plugins, add_requirements=False), @@ -358,21 +344,14 @@ def reply( if self.config.enable_auto_plugin_selection: self.plugin_pool = self.select_plugins_for_prompt(query) - exp_sub_paths = memory.get_shared_memory_entries(entry_type="experience_sub_path") - - if exp_sub_paths: - self.tracing.set_span_attribute("experience_sub_path", str(exp_sub_paths)) - exp_sub_path = exp_sub_paths[0].content - else: - exp_sub_path = "" - selected_experiences = self.load_experience(query=query, sub_path=exp_sub_path) + self.role_load_experience(query=query, memory=memory) + self.role_load_example(memory=memory, role_set={self.alias, "Planner"}) planning_enrichments = memory.get_shared_memory_entries(entry_type="plan") prompt = self.compose_prompt( rounds, self.plugin_pool, - selected_experiences, planning_enrichments=[pe.content for pe in planning_enrichments], ) @@ -440,16 +419,6 @@ def format_plugins( ) return "" - def load_examples( - self, - ) -> List[Conversation]: - if self.config.load_example: - return load_examples( - folder=self.config.example_base_path, - role_set={self.alias, "Planner"}, - ) - return [] - def get_plugin_pool(self) -> List[PluginEntry]: return self.plugin_pool diff --git a/taskweaver/memory/experience.py b/taskweaver/memory/experience.py index 0940362f..455d330a 100644 --- a/taskweaver/memory/experience.py +++ b/taskweaver/memory/experience.py @@ -242,7 +242,7 @@ def load_experience(self): exp_ids = [os.path.splitext(os.path.basename(exp_file))[0].split("_")[2] for exp_file in original_exp_files] if len(exp_ids) == 0: self.logger.warning( - "No experience found." + "No experience found.", ) return @@ -253,18 +253,14 @@ def load_experience(self): exp_file = f"exp_{exp_id}.yaml" exp_file_path = os.path.join(exp_dir, exp_file) - assert os.path.exists(exp_file_path), ( - f"Experience {exp_file} not found. " - ) + assert os.path.exists(exp_file_path), f"Experience {exp_file} not found. " experience = read_yaml(exp_file_path) - assert len(experience["embedding"]) > 0, ( - f"Experience {exp_file} has no embedding." - ) - assert experience["embedding_model"] == self.llm_api.embedding_service.config.embedding_model, ( - f"Experience {exp_file} has different embedding model." - ) + assert len(experience["embedding"]) > 0, f"Experience {exp_file} has no embedding." + assert ( + experience["embedding_model"] == self.llm_api.embedding_service.config.embedding_model + ), f"Experience {exp_file} has different embedding model." self.experience_list.append(Experience(**experience)) @@ -326,13 +322,13 @@ def delete_handcrafted_experience(self, exp_id: str): @staticmethod def format_experience_in_prompt( prompt_template: str, - selected_experiences: Optional[List[Tuple[Experience, float]]] = None, + selected_experiences: Optional[List[Experience,]] = None, ): if selected_experiences is not None and len(selected_experiences) > 0: return prompt_template.format( experiences="===================\n" + "\n===================\n".join( - [exp.experience_text for exp, _ in selected_experiences], + [exp.experience_text for exp in selected_experiences], ), ) else: diff --git a/taskweaver/memory/type_vars.py b/taskweaver/memory/type_vars.py index ffbf31bd..a12cc400 100644 --- a/taskweaver/memory/type_vars.py +++ b/taskweaver/memory/type_vars.py @@ -2,5 +2,5 @@ RoleName = str RoundState = Literal["finished", "failed", "created"] -SharedMemoryEntryType = Literal["plan", "experience_sub_path"] +SharedMemoryEntryType = Literal["plan", "experience_sub_path", "example_sub_path"] SharedMemoryEntryScope = Literal["round", "conversation"] diff --git a/taskweaver/misc/example.py b/taskweaver/misc/example.py index 026e9a24..92f97947 100644 --- a/taskweaver/misc/example.py +++ b/taskweaver/misc/example.py @@ -7,6 +7,7 @@ def load_examples( folder: str, + sub_path: Optional[str] = None, role_set: Optional[Set[str]] = None, ) -> List[Conversation]: """ @@ -14,8 +15,14 @@ def load_examples( Args: folder: the folder path. + sub_path: the sub-folder path. role_set: the roles should be included in the examples. """ + if sub_path: + folder = path.join(folder, sub_path) + if not path.exists(folder): + raise FileNotFoundError(f"Folder {folder} does not exist.") + example_file_list: List[str] = glob.glob(path.join(folder, "*.yaml")) example_conv_pool: List[Conversation] = [] for yaml_path in example_file_list: diff --git a/taskweaver/planner/planner.py b/taskweaver/planner/planner.py index 7823393f..5e7ac13a 100644 --- a/taskweaver/planner/planner.py +++ b/taskweaver/planner/planner.py @@ -3,18 +3,17 @@ import os import types from json import JSONDecodeError -from typing import Dict, Iterable, List, Optional, Tuple +from typing import Dict, Iterable, List, Optional from injector import inject from taskweaver.llm import LLMApi from taskweaver.llm.util import ChatMessageType, format_chat_message from taskweaver.logging import TelemetryLogger -from taskweaver.memory import Conversation, Memory, Post, Round, RoundCompressor +from taskweaver.memory import Memory, Post, Round, RoundCompressor from taskweaver.memory.attachment import AttachmentType -from taskweaver.memory.experience import Experience, ExperienceGenerator +from taskweaver.memory.experience import ExperienceGenerator from taskweaver.memory.memory import SharedMemoryEntry -from taskweaver.misc.example import load_examples from taskweaver.module.event_emitter import SessionEventEmitter from taskweaver.module.tracing import Tracing, tracing_decorator from taskweaver.role import PostTranslator, Role @@ -25,8 +24,6 @@ class PlannerConfig(RoleConfig): def _configure(self) -> None: self._set_name("planner") - app_dir = self.src.app_base_path - self.use_example = self._get_bool("use_example", True) self.prompt_file_path = self._get_path( "prompt_file_path", os.path.join( @@ -34,13 +31,6 @@ def _configure(self) -> None: "planner_prompt.yaml", ), ) - self.example_base_path = self._get_path( - "example_base_path", - os.path.join( - app_dir, - "planner_examples", - ), - ) self.prompt_compression = self._get_bool("prompt_compression", False) self.compression_prompt_path = self._get_path( "compression_prompt_path", @@ -82,9 +72,6 @@ def __init__( self.prompt_data = read_yaml(self.config.prompt_file_path) - if self.config.use_example: - self.examples = self.get_examples() - self.instruction_template = self.prompt_data["instruction_template"] self.response_json_schema = json.loads(self.prompt_data["response_json_schema"]) @@ -211,11 +198,9 @@ def get_env_context(self) -> str: def compose_prompt( self, rounds: List[Round], - selected_experiences: Optional[List[Tuple[Experience, float]]] = None, ) -> List[ChatMessageType]: experiences = self.format_experience( template=self.prompt_data["experience_instruction"], - experiences=selected_experiences, ) chat_history = [ @@ -225,12 +210,11 @@ def compose_prompt( ), ] - if self.config.use_example and len(self.examples) != 0: - for conv_example in self.examples: - conv_example_in_prompt = self.compose_conversation_for_prompt( - conv_example.rounds, - ) - chat_history += conv_example_in_prompt + for conv_example in self.examples: + conv_example_in_prompt = self.compose_conversation_for_prompt( + conv_example.rounds, + ) + chat_history += conv_example_in_prompt summary = None if self.config.prompt_compression and self.round_compressor is not None: @@ -266,19 +250,13 @@ def reply( self.tracing.set_span_attribute("user_query", user_query) self.tracing.set_span_attribute("use_experience", self.config.use_experience) - exp_sub_paths = memory.get_shared_memory_entries(entry_type="experience_sub_path") - - if exp_sub_paths: - self.tracing.set_span_attribute("experience_sub_path", str(exp_sub_paths)) - exp_sub_path = exp_sub_paths[0].content - else: - exp_sub_path = "" - selected_experiences = self.load_experience(query=user_query, sub_path=exp_sub_path) + self.role_load_experience(query=user_query, memory=memory) + self.role_load_example(role_set=set(self.recipient_alias_set) | {self.alias, "User"}, memory=memory) post_proxy = self.event_emitter.create_post_proxy(self.alias) post_proxy.update_status("composing prompt") - chat_history = self.compose_prompt(rounds, selected_experiences) + chat_history = self.compose_prompt(rounds) def check_post_validity(post: Post): missing_elements: List[str] = [] @@ -392,10 +370,3 @@ def stream_filter(s: Iterable[ChatMessageType]): self.tracing.set_span_attribute("out.attachments", str(reply_post.attachment_list)) return reply_post - - def get_examples(self) -> List[Conversation]: - example_conv_list = load_examples( - self.config.example_base_path, - role_set=set(self.recipient_alias_set) | {self.alias, "User"}, - ) - return example_conv_list diff --git a/taskweaver/role/role.py b/taskweaver/role/role.py index 64c48ff5..21872ede 100644 --- a/taskweaver/role/role.py +++ b/taskweaver/role/role.py @@ -2,16 +2,17 @@ import os.path from dataclasses import dataclass from datetime import timedelta -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Set, Tuple, Union from injector import Module, inject, provider from taskweaver.config.config_mgt import AppConfigSource from taskweaver.config.module_config import ModuleConfig from taskweaver.logging import TelemetryLogger -from taskweaver.memory import Memory, Post +from taskweaver.memory import Conversation, Memory, Post from taskweaver.memory.experience import Experience, ExperienceGenerator from taskweaver.misc.component_registry import ComponentRegistry +from taskweaver.misc.example import load_examples from taskweaver.module.event_emitter import SessionEventEmitter from taskweaver.module.tracing import Tracing from taskweaver.utils import import_module, read_yaml @@ -62,6 +63,23 @@ def __init__(self, src: AppConfigSource) -> None: False, ) + self.use_example = self._get_bool( + "use_example", + True, + ) + self.example_base_path = self._get_path( + "example_base_path", + os.path.join( + self.src.app_base_path, + "examples", + f"{self.name}_examples", + ), + ) + self.dynamic_example_sub_path = self._get_bool( + "dynamic_example_sub_path", + False, + ) + self._configure() def _set_role_name(self): @@ -103,9 +121,13 @@ def __init__( self.alias: str = self.role_entry.alias if self.role_entry else "" self.intro: str = self.role_entry.intro if self.role_entry else "" + self.experiences: List[Experience] = [] self.experience_generator: Optional[ExperienceGenerator] = None self.experience_loaded_from: Optional[str] = None + self.examples: List[Conversation] = [] + self.example_loaded_from: Optional[str] = None + def get_intro(self) -> str: return self.intro @@ -121,60 +143,111 @@ def reply(self, memory: Memory, **kwargs: ...) -> Post: def close(self) -> None: self.logger.info(f"{self.alias} closed successfully") - def load_experience( + def format_experience( + self, + template: str, + ) -> str: + return ( + self.experience_generator.format_experience_in_prompt(template, self.experiences) + if self.config.use_experience + else "" + ) + + def role_load_experience( self, query: str, - sub_path: str = "", - ) -> List[Tuple[Experience, float]]: + memory: Optional[Memory] = None, + ) -> None: + if not self.config.use_experience: + self.experiences = [] + return + if self.experience_generator is None: - raise ValueError("Experience generator is not initialized.") + raise ValueError( + "Experience generator is not initialized. Each role instance should have its own generator.", + ) - if self.config.use_experience: - if not self.config.dynamic_experience_sub_path: - self._load_experience() - elif sub_path: - self._load_experience(sub_path=sub_path) + experience_sub_path = "" + if self.config.dynamic_experience_sub_path: + assert memory is not None, "Memory should be provided when dynamic_experience_sub_path is True" + experience_sub_paths = memory.get_shared_memory_entries(entry_type="experience_sub_path") + if experience_sub_paths: + self.tracing.set_span_attribute("experience_sub_path", str(experience_sub_paths)) + # todo: handle multiple experience sub paths + experience_sub_path = experience_sub_paths[0].content else: - # if sub_path is empty, experience should not have been loaded - assert self.experience_loaded_from is None, "sub_path is empty when dynamic_experience_sub_path is True" - - return self.experience_generator.retrieve_experience(query) - else: - return [] + self.logger.info("No experience sub path found in memory.") + self.experiences = [] + return - def _load_experience(self, sub_path: str = "") -> None: - load_from = os.path.join(self.config.experience_dir, sub_path) + load_from = os.path.join(self.config.experience_dir, experience_sub_path) if self.experience_loaded_from is None or self.experience_loaded_from != load_from: self.experience_loaded_from = load_from self.experience_generator.set_experience_dir(self.config.experience_dir) - self.experience_generator.set_sub_path(sub_path) + self.experience_generator.set_sub_path(experience_sub_path) self.experience_generator.refresh() self.experience_generator.load_experience() self.logger.info( "Experience loaded successfully for {}, there are {} experiences with filter [{}]".format( self.alias, len(self.experience_generator.experience_list), - sub_path, + experience_sub_path, ), ) else: self.logger.info(f"Experience already loaded from {load_from}.") - def format_experience( + experiences = self.experience_generator.retrieve_experience(query) + self.logger.info(f"Retrieved {len(experiences)} experiences for query [{query}]") + self.experiences = [exp for exp, _ in experiences] + + # todo: `role_load_example` is similar to `role_load_experience`, consider refactoring + def role_load_example( self, - template: str, - experiences: Optional[List[Tuple[Experience, float]]], - ) -> str: - experiences_str = ( - self.experience_generator.format_experience_in_prompt( - template, - experiences, + role_set: Set[str], + memory: Optional[Memory] = None, + ) -> None: + if not self.config.use_example: + self.examples = [] + return + + if not os.path.exists(self.config.example_base_path): + raise FileNotFoundError( + f"The default example base path {self.config.example_base_path} does not exist." + "The original example base paths have been changed to `examples` folder." + "Please migrate the examples to the new base path.", ) - if self.config.use_experience - else "" - ) - return experiences_str + example_sub_path = "" + if self.config.dynamic_example_sub_path: + assert memory is not None, "Memory should be provided when dynamic_example_sub_path is True" + example_sub_paths = memory.get_shared_memory_entries(entry_type="example_sub_path") + if example_sub_paths: + self.tracing.set_span_attribute("example_sub_path", str(example_sub_paths)) + # todo: handle multiple sub paths + example_sub_path = example_sub_paths[0].content + else: + self.logger.info("No example sub path found in memory.") + self.examples = [] + return + + load_from = os.path.join(self.config.example_base_path, example_sub_path) + if self.example_loaded_from is None or self.example_loaded_from != load_from: + self.example_loaded_from = load_from + self.examples = load_examples( + folder=self.config.example_base_path, + sub_path=example_sub_path, + role_set=role_set, + ) + self.logger.info( + "Example loaded successfully for {}, there are {} examples with filter [{}]".format( + self.alias, + len(self.examples), + example_sub_path, + ), + ) + else: + self.logger.info(f"Example already loaded from {load_from}.") class RoleModuleConfig(ModuleConfig): diff --git a/tests/unit_tests/data/examples/planner_examples/sub/example-planner.yaml b/tests/unit_tests/data/examples/planner_examples/sub/example-planner.yaml new file mode 100644 index 00000000..525463ab --- /dev/null +++ b/tests/unit_tests/data/examples/planner_examples/sub/example-planner.yaml @@ -0,0 +1,43 @@ +enabled: True +rounds: + - user_query: count the rows of /home/data.csv + state: created + post_list: + - message: count the rows of /home/data.csv + send_from: User + send_to: Planner + attachment_list: + - message: Please load the data file /home/data.csv and count the rows of the loaded data + send_from: Planner + send_to: CodeInterpreter + attachment_list: + - type: init_plan + content: |- + 1. load the data file + 2. count the rows of the loaded data + 3. report the result to the user + - type: plan + content: |- + 1. instruct CodeInterpreter to load the data file and count the rows of the loaded data + 2. report the result to the user + - type: current_plan_step + content: 1. instruct CodeInterpreter to load the data file and count the rows of the loaded data + - message: Load the data file /home/data.csv successfully and there are 100 rows in the data file + send_from: CodeInterpreter + send_to: Planner + attachment_list: + - message: The data file /home/data.csv is loaded and there are 100 rows in the data file + send_from: Planner + send_to: User + attachment_list: + - type: init_plan + content: |- + 1. load the data file + 2. count the rows of the loaded data + 3. report the result to the user + - type: plan + content: |- + 1. instruct CodeInterpreter to load the data file and count the rows of the loaded data + 2. report the result to the user + - type: current_plan_step + content: 2. report the result to the user \ No newline at end of file diff --git a/tests/unit_tests/test_code_generator.py b/tests/unit_tests/test_code_generator.py index 666a020c..5b595f48 100644 --- a/tests/unit_tests/test_code_generator.py +++ b/tests/unit_tests/test_code_generator.py @@ -464,6 +464,8 @@ def test_compose_prompt_with_not_plugin_only(): memory = Memory(session_id="session-1") memory.conversation.add_round(round1) + code_generator.role_load_example({"Planner", "CodeInterpreter"}, memory) + messages = code_generator.compose_prompt( rounds=memory.conversation.rounds, plugins=code_generator.get_plugin_pool(), @@ -620,26 +622,20 @@ def test_compose_with_shared_plan(): memory.conversation.add_round(round1) selected_experiences = [ - ( - Experience( - exp_id="exp-1", - experience_text="this is a test experience", - ), - 0.3, + Experience( + exp_id="exp-1", + experience_text="this is a test experience", ), - ( - Experience( - exp_id="exp-2", - experience_text="this is another test experience", - ), - 0.2, + Experience( + exp_id="exp-2", + experience_text="this is another test experience", ), ] + code_generator.experiences = selected_experiences messages = code_generator.compose_prompt( rounds=memory.conversation.rounds, plugins=code_generator.get_plugin_pool(), - selected_experiences=selected_experiences, planning_enrichments=[ "shared_memory_entry1", "shared_memory_entry2", diff --git a/tests/unit_tests/test_example.py b/tests/unit_tests/test_example.py index 2ab016d0..4bbe3697 100644 --- a/tests/unit_tests/test_example.py +++ b/tests/unit_tests/test_example.py @@ -10,20 +10,48 @@ def test_load_examples(): "examples", "planner_examples", ) - examples = load_examples(example_path, {"Planner", "User", "CodeInterpreter"}) + sub_path = "" + examples = load_examples(example_path, sub_path, {"Planner", "User", "CodeInterpreter"}) assert len(examples) == 1 - examples = load_examples(example_path) + examples = load_examples(example_path, sub_path) assert len(examples) == 1 - examples = load_examples(example_path, {"Planner"}) + examples = load_examples(example_path, sub_path, {"Planner"}) assert len(examples) == 0 - examples = load_examples(example_path, {"User"}) + examples = load_examples(example_path, sub_path, {"User"}) assert len(examples) == 0 - examples = load_examples(example_path, {"Planner", "User", "Other"}) + examples = load_examples(example_path, sub_path, {"Planner", "User", "Other"}) assert len(examples) == 0 - examples = load_examples(example_path, {"Planner", "User", "CodeInterpreter", "Other"}) + examples = load_examples(example_path, sub_path, {"Planner", "User", "CodeInterpreter", "Other"}) + assert len(examples) == 1 + + +def test_load_sub_examples(): + example_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "data", + "examples", + "planner_examples", + ) + sub_path = "sub" + examples = load_examples(example_path, sub_path, {"Planner", "User", "CodeInterpreter"}) + assert len(examples) == 1 + + examples = load_examples(example_path, sub_path) + assert len(examples) == 1 + + examples = load_examples(example_path, sub_path, {"Planner"}) + assert len(examples) == 0 + + examples = load_examples(example_path, sub_path, {"User"}) + assert len(examples) == 0 + + examples = load_examples(example_path, sub_path, {"Planner", "User", "Other"}) + assert len(examples) == 0 + + examples = load_examples(example_path, sub_path, {"Planner", "User", "CodeInterpreter", "Other"}) assert len(examples) == 1 diff --git a/tests/unit_tests/test_planner.py b/tests/unit_tests/test_planner.py index 447feca0..0835ac1d 100644 --- a/tests/unit_tests/test_planner.py +++ b/tests/unit_tests/test_planner.py @@ -276,6 +276,8 @@ def test_compose_example_for_prompt(): memory = Memory(session_id="session-1") memory.conversation.add_round(round1) + planner.role_load_example({"Planner", "CodeInterpreter", "User"}, memory) + messages = planner.compose_prompt(rounds=memory.conversation.rounds) assert messages[0]["role"] == "system" diff --git a/tests/unit_tests/test_role.py b/tests/unit_tests/test_role.py index f02772e2..18672d76 100644 --- a/tests/unit_tests/test_role.py +++ b/tests/unit_tests/test_role.py @@ -5,6 +5,8 @@ from taskweaver.config.config_mgt import AppConfigSource from taskweaver.logging import LoggingModule +from taskweaver.memory import Attachment, Memory, Post, Round, SharedMemoryEntry +from taskweaver.memory.attachment import AttachmentType from taskweaver.memory.experience import ExperienceGenerator from taskweaver.memory.plugin import PluginModule from taskweaver.role import Role @@ -34,7 +36,7 @@ def test_role_load_experience(): role.experience_generator = app_injector.create_object(ExperienceGenerator) - role.load_experience("test") + role.role_load_experience("test") assert len(role.experience_generator.experience_list) == 1 @@ -61,13 +63,35 @@ def test_role_load_experience_sub_path(): role.experience_generator = app_injector.create_object(ExperienceGenerator) - role.load_experience("test") + memory = Memory(session_id="session-1") + + role.role_load_experience("test", memory=memory) assert len(role.experience_generator.experience_list) == 0 - role.load_experience("test", "sub_path") + post1 = Post.create( + message="create a dataframe", + send_from="Planner", + send_to="CodeInterpreter", + attachment_list=[ + Attachment.create( + type=AttachmentType.shared_memory_entry, + content="", + extra=SharedMemoryEntry.create( + type="experience_sub_path", + content="sub_path", + scope="conversation", + ), + ), + ], + ) + round1 = Round.create(user_query="hello", id="round-1") + round1.add_post(post1) + memory.conversation.add_round(round1) + + role.role_load_experience("test", memory=memory) assert len(role.experience_generator.experience_list) == 1 try: - role.load_experience("test") + role.role_load_experience("test") except AssertionError as e: - assert str(e) == "sub_path is empty when dynamic_experience_sub_path is True" + assert str(e) == "Memory should be provided when dynamic_experience_sub_path is True" diff --git a/website/blog/experience.md b/website/blog/experience.md index 876186c6..831f281a 100644 --- a/website/blog/experience.md +++ b/website/blog/experience.md @@ -9,7 +9,7 @@ In this blog post, we discuss more advanced topics about the experience module o Every role in TaskWeaver can configure its own experience directory, which can be configured by setting the `role_name.experience_dir` field in the project configuration file. For the `Planner` and `CodeInterpreter` roles, you can configure the experience directory -by setting the `planner.experience_dir` and `code_interpreter.experience_dir` fields respectively. +by setting the `planner.experience_dir` and `code_generator.experience_dir` fields respectively. The default experience directory is `experience` in the project directory. @@ -106,7 +106,7 @@ if exp_sub_paths: exp_sub_path = exp_sub_paths[0].content else: exp_sub_path = "" -selected_experiences = self.load_experience(query=query, sub_path=exp_sub_path) +selected_experiences = self.role_load_experience(query=query, sub_path=exp_sub_path) ``` :::tip diff --git a/website/docs/concepts/project.md b/website/docs/concepts/project.md index 2ea9e27b..585f42e4 100644 --- a/website/docs/concepts/project.md +++ b/website/docs/concepts/project.md @@ -8,9 +8,10 @@ The following is a typical project directory structure: 📦project ┣ 📜taskweaver_config.json # the project configuration file for TaskWeaver ┣ 📂plugins # the folder to store plugins - ┣ 📂planner_examples # the folder to store planner examples - ┣ 📂codeinterpreter_examples # the folder to store code interpreter examples ┣ 📂logs # the folder to store logs, will be generated after program starts + ┣ 📂examples + ┣ 📂 planner_examples # the folder to store planner examples + ┗ 📂 code_generator_examples # the folder to store code generator examples ┗ 📂workspace # the directory stores session data, will be generated after program starts ┗ 📂 session_id ┣ 📂ces # the folder used by the code execution service diff --git a/website/docs/quickstart.md b/website/docs/quickstart.md index 2a30dbf9..e5b89ad8 100644 --- a/website/docs/quickstart.md +++ b/website/docs/quickstart.md @@ -24,9 +24,10 @@ A project directory typically contains the following files and folders: 📦project ┣ 📜taskweaver_config.json # the project configuration file for TaskWeaver ┣ 📂plugins # the folder to store plugins - ┣ 📂planner_examples # the folder to store planner examples - ┣ 📂codeinterpreter_examples # the folder to store code interpreter examples ┣ 📂logs # the folder to store logs, will be generated after program starts + ┣ 📂examples + ┣ 📂 planner_examples # the folder to store planner examples + ┗ 📂 code_generator_examples # the folder to store code generator examples ┗ 📂workspace # the directory stores session data, will be generated after program starts ┗ 📂 session_id ┣ 📂ces # the folder used by the code execution service