Skip to content

Commit

Permalink
Dynamic loading examples (#420)
Browse files Browse the repository at this point in the history
1. support dynamically loading examples for different roles
2. move loading example to role base class
3. rename `codeinterpreter_examples` to `code_generator_examples` in the
project folder; this requires manual migration
  • Loading branch information
liqul authored Sep 26, 2024
1 parent 11d13c4 commit 2d8086d
Show file tree
Hide file tree
Showing 19 changed files with 263 additions and 152 deletions.
41 changes: 5 additions & 36 deletions taskweaver/code_interpreter/code_interpreter/code_generator.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
import datetime
import json
import os
from typing import List, Optional, Tuple
from typing import List, Optional

from injector import inject

from taskweaver.code_interpreter.plugin_selection import PluginSelector, SelectedPluginPool
from taskweaver.llm import LLMApi
from taskweaver.llm.util import ChatMessageType, format_chat_message
from taskweaver.logging import TelemetryLogger
from taskweaver.memory import Attachment, Conversation, Memory, Post, Round, RoundCompressor
from taskweaver.memory import Attachment, Memory, Post, Round, RoundCompressor
from taskweaver.memory.attachment import AttachmentType
from taskweaver.memory.experience import Experience, ExperienceGenerator
from taskweaver.memory.experience import ExperienceGenerator
from taskweaver.memory.plugin import PluginEntry, PluginRegistry
from taskweaver.misc.example import load_examples
from taskweaver.module.event_emitter import PostEventProxy, SessionEventEmitter
from taskweaver.module.tracing import Tracing, tracing_decorator
from taskweaver.role import PostTranslator, Role
Expand All @@ -26,21 +25,13 @@ def _configure(self) -> None:
self._set_name("code_generator")
self.role_name = self._get_str("role_name", "ProgramApe")
self.load_plugin = self._get_bool("load_plugin", True)
self.load_example = self._get_bool("load_example", True)
self.prompt_file_path = self._get_path(
"prompt_file_path",
os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"code_generator_prompt.yaml",
),
)
self.example_base_path = self._get_path(
"example_base_path",
os.path.join(
self.src.app_base_path,
"codeinterpreter_examples",
),
)
self.prompt_compression = self._get_bool("prompt_compression", False)
self.compression_prompt_path = self._get_path(
"compression_prompt_path",
Expand Down Expand Up @@ -89,7 +80,6 @@ def __init__(
self.query_requirements_template = self.prompt_data["requirements"]
self.response_json_schema = json.loads(self.prompt_data["response_json_schema"])

self.examples = None
self.code_verification_on: bool = False
self.allowed_modules: List[str] = []

Expand Down Expand Up @@ -157,12 +147,10 @@ def compose_prompt(
self,
rounds: List[Round],
plugins: List[PluginEntry],
selected_experiences: Optional[List[Tuple[Experience, float]]] = None,
planning_enrichments: Optional[List[str]] = None,
) -> List[ChatMessageType]:
experiences = self.format_experience(
template=self.prompt_data["experience_instruction"],
experiences=selected_experiences,
)

chat_history = [
Expand All @@ -172,8 +160,6 @@ def compose_prompt(
),
]

if self.examples is None:
self.examples = self.load_examples()
for i, example in enumerate(self.examples):
chat_history.extend(
self.compose_conversation(example.rounds, example.plugins, add_requirements=False),
Expand Down Expand Up @@ -358,21 +344,14 @@ def reply(
if self.config.enable_auto_plugin_selection:
self.plugin_pool = self.select_plugins_for_prompt(query)

exp_sub_paths = memory.get_shared_memory_entries(entry_type="experience_sub_path")

if exp_sub_paths:
self.tracing.set_span_attribute("experience_sub_path", str(exp_sub_paths))
exp_sub_path = exp_sub_paths[0].content
else:
exp_sub_path = ""
selected_experiences = self.load_experience(query=query, sub_path=exp_sub_path)
self.role_load_experience(query=query, memory=memory)
self.role_load_example(memory=memory, role_set={self.alias, "Planner"})

planning_enrichments = memory.get_shared_memory_entries(entry_type="plan")

prompt = self.compose_prompt(
rounds,
self.plugin_pool,
selected_experiences,
planning_enrichments=[pe.content for pe in planning_enrichments],
)

Expand Down Expand Up @@ -440,16 +419,6 @@ def format_plugins(
)
return ""

def load_examples(
self,
) -> List[Conversation]:
if self.config.load_example:
return load_examples(
folder=self.config.example_base_path,
role_set={self.alias, "Planner"},
)
return []

def get_plugin_pool(self) -> List[PluginEntry]:
return self.plugin_pool

Expand Down
20 changes: 8 additions & 12 deletions taskweaver/memory/experience.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def load_experience(self):
exp_ids = [os.path.splitext(os.path.basename(exp_file))[0].split("_")[2] for exp_file in original_exp_files]
if len(exp_ids) == 0:
self.logger.warning(
"No experience found."
"No experience found.",
)
return

Expand All @@ -253,18 +253,14 @@ def load_experience(self):

exp_file = f"exp_{exp_id}.yaml"
exp_file_path = os.path.join(exp_dir, exp_file)
assert os.path.exists(exp_file_path), (
f"Experience {exp_file} not found. "
)
assert os.path.exists(exp_file_path), f"Experience {exp_file} not found. "

experience = read_yaml(exp_file_path)

assert len(experience["embedding"]) > 0, (
f"Experience {exp_file} has no embedding."
)
assert experience["embedding_model"] == self.llm_api.embedding_service.config.embedding_model, (
f"Experience {exp_file} has different embedding model."
)
assert len(experience["embedding"]) > 0, f"Experience {exp_file} has no embedding."
assert (
experience["embedding_model"] == self.llm_api.embedding_service.config.embedding_model
), f"Experience {exp_file} has different embedding model."

self.experience_list.append(Experience(**experience))

Expand Down Expand Up @@ -326,13 +322,13 @@ def delete_handcrafted_experience(self, exp_id: str):
@staticmethod
def format_experience_in_prompt(
prompt_template: str,
selected_experiences: Optional[List[Tuple[Experience, float]]] = None,
selected_experiences: Optional[List[Experience,]] = None,
):
if selected_experiences is not None and len(selected_experiences) > 0:
return prompt_template.format(
experiences="===================\n"
+ "\n===================\n".join(
[exp.experience_text for exp, _ in selected_experiences],
[exp.experience_text for exp in selected_experiences],
),
)
else:
Expand Down
2 changes: 1 addition & 1 deletion taskweaver/memory/type_vars.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@

RoleName = str
RoundState = Literal["finished", "failed", "created"]
SharedMemoryEntryType = Literal["plan", "experience_sub_path"]
SharedMemoryEntryType = Literal["plan", "experience_sub_path", "example_sub_path"]
SharedMemoryEntryScope = Literal["round", "conversation"]
7 changes: 7 additions & 0 deletions taskweaver/misc/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,22 @@

def load_examples(
folder: str,
sub_path: Optional[str] = None,
role_set: Optional[Set[str]] = None,
) -> List[Conversation]:
"""
Load all the examples from a folder.
Args:
folder: the folder path.
sub_path: the sub-folder path.
role_set: the roles should be included in the examples.
"""
if sub_path:
folder = path.join(folder, sub_path)
if not path.exists(folder):
raise FileNotFoundError(f"Folder {folder} does not exist.")

example_file_list: List[str] = glob.glob(path.join(folder, "*.yaml"))
example_conv_pool: List[Conversation] = []
for yaml_path in example_file_list:
Expand Down
51 changes: 11 additions & 40 deletions taskweaver/planner/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,17 @@
import os
import types
from json import JSONDecodeError
from typing import Dict, Iterable, List, Optional, Tuple
from typing import Dict, Iterable, List, Optional

from injector import inject

from taskweaver.llm import LLMApi
from taskweaver.llm.util import ChatMessageType, format_chat_message
from taskweaver.logging import TelemetryLogger
from taskweaver.memory import Conversation, Memory, Post, Round, RoundCompressor
from taskweaver.memory import Memory, Post, Round, RoundCompressor
from taskweaver.memory.attachment import AttachmentType
from taskweaver.memory.experience import Experience, ExperienceGenerator
from taskweaver.memory.experience import ExperienceGenerator
from taskweaver.memory.memory import SharedMemoryEntry
from taskweaver.misc.example import load_examples
from taskweaver.module.event_emitter import SessionEventEmitter
from taskweaver.module.tracing import Tracing, tracing_decorator
from taskweaver.role import PostTranslator, Role
Expand All @@ -25,22 +24,13 @@
class PlannerConfig(RoleConfig):
def _configure(self) -> None:
self._set_name("planner")
app_dir = self.src.app_base_path
self.use_example = self._get_bool("use_example", True)
self.prompt_file_path = self._get_path(
"prompt_file_path",
os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"planner_prompt.yaml",
),
)
self.example_base_path = self._get_path(
"example_base_path",
os.path.join(
app_dir,
"planner_examples",
),
)
self.prompt_compression = self._get_bool("prompt_compression", False)
self.compression_prompt_path = self._get_path(
"compression_prompt_path",
Expand Down Expand Up @@ -82,9 +72,6 @@ def __init__(

self.prompt_data = read_yaml(self.config.prompt_file_path)

if self.config.use_example:
self.examples = self.get_examples()

self.instruction_template = self.prompt_data["instruction_template"]

self.response_json_schema = json.loads(self.prompt_data["response_json_schema"])
Expand Down Expand Up @@ -211,11 +198,9 @@ def get_env_context(self) -> str:
def compose_prompt(
self,
rounds: List[Round],
selected_experiences: Optional[List[Tuple[Experience, float]]] = None,
) -> List[ChatMessageType]:
experiences = self.format_experience(
template=self.prompt_data["experience_instruction"],
experiences=selected_experiences,
)

chat_history = [
Expand All @@ -225,12 +210,11 @@ def compose_prompt(
),
]

if self.config.use_example and len(self.examples) != 0:
for conv_example in self.examples:
conv_example_in_prompt = self.compose_conversation_for_prompt(
conv_example.rounds,
)
chat_history += conv_example_in_prompt
for conv_example in self.examples:
conv_example_in_prompt = self.compose_conversation_for_prompt(
conv_example.rounds,
)
chat_history += conv_example_in_prompt

summary = None
if self.config.prompt_compression and self.round_compressor is not None:
Expand Down Expand Up @@ -266,19 +250,13 @@ def reply(
self.tracing.set_span_attribute("user_query", user_query)
self.tracing.set_span_attribute("use_experience", self.config.use_experience)

exp_sub_paths = memory.get_shared_memory_entries(entry_type="experience_sub_path")

if exp_sub_paths:
self.tracing.set_span_attribute("experience_sub_path", str(exp_sub_paths))
exp_sub_path = exp_sub_paths[0].content
else:
exp_sub_path = ""
selected_experiences = self.load_experience(query=user_query, sub_path=exp_sub_path)
self.role_load_experience(query=user_query, memory=memory)
self.role_load_example(role_set=set(self.recipient_alias_set) | {self.alias, "User"}, memory=memory)

post_proxy = self.event_emitter.create_post_proxy(self.alias)

post_proxy.update_status("composing prompt")
chat_history = self.compose_prompt(rounds, selected_experiences)
chat_history = self.compose_prompt(rounds)

def check_post_validity(post: Post):
missing_elements: List[str] = []
Expand Down Expand Up @@ -392,10 +370,3 @@ def stream_filter(s: Iterable[ChatMessageType]):
self.tracing.set_span_attribute("out.attachments", str(reply_post.attachment_list))

return reply_post

def get_examples(self) -> List[Conversation]:
example_conv_list = load_examples(
self.config.example_base_path,
role_set=set(self.recipient_alias_set) | {self.alias, "User"},
)
return example_conv_list
Loading

0 comments on commit 2d8086d

Please sign in to comment.