Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Experience Selection [Ready] #406

Merged
merged 13 commits into from
Sep 14, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 0 additions & 96 deletions scripts/experience_mgt.py

This file was deleted.

34 changes: 12 additions & 22 deletions taskweaver/code_interpreter/code_interpreter/code_generator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import datetime
import json
import os
from typing import List, Optional
from typing import List, Optional, Tuple

from injector import inject

Expand Down Expand Up @@ -55,8 +55,6 @@ def _configure(self) -> None:
)
self.auto_plugin_selection_topk = self._get_int("auto_plugin_selection_topk", 3)

self.use_experience = self._get_bool("use_experience", False)

self.llm_alias = self._get_str("llm_alias", default="", required=False)


Expand Down Expand Up @@ -104,14 +102,7 @@ def __init__(
logger.info("Plugin embeddings loaded")
self.selected_plugin_pool = SelectedPluginPool()

if self.config.use_experience:
self.experience_generator = experience_generator
self.experience_generator.refresh()
self.experience_generator.load_experience()
self.logger.info(
"Experience loaded successfully, "
"there are {} experiences".format(len(self.experience_generator.experience_list)),
)
self.experience_generator = experience_generator

self.logger.info("CodeGenerator initialized successfully")

Expand Down Expand Up @@ -166,15 +157,11 @@ def compose_prompt(
self,
rounds: List[Round],
plugins: List[PluginEntry],
selected_experiences: Optional[List[Experience]] = None,
selected_experiences: Optional[List[Tuple[Experience, float]]] = None,
) -> List[ChatMessageType]:
experiences = (
self.experience_generator.format_experience_in_prompt(
self.prompt_data["experience_instruction"],
selected_experiences,
)
if self.config.use_experience
else ""
experiences = self.format_experience(
template=self.prompt_data["experience_instruction"],
experiences=selected_experiences,
)

chat_history = [
Expand Down Expand Up @@ -370,10 +357,13 @@ def reply(
if self.config.enable_auto_plugin_selection:
self.plugin_pool = self.select_plugins_for_prompt(query)

if self.config.use_experience:
selected_experiences = self.experience_generator.retrieve_experience(query)
exp_sub_path = rounds[-1].post_list[-1].get_attachment(AttachmentType._signal_exp_sub_path)
if exp_sub_path:
self.tracing.set_span_attribute("exp_sub_path", exp_sub_path[0])
exp_sub_path = exp_sub_path[0]
else:
selected_experiences = None
exp_sub_path = ""
selected_experiences = self.load_experience(query=query, sub_path=exp_sub_path)

prompt = self.compose_prompt(rounds, self.plugin_pool, selected_experiences)
self.tracing.set_span_attribute("prompt", json.dumps(prompt, indent=2))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,6 @@ requirements: |-

experience_instruction: |-
## Experience And Lessons
Before generating Python code, please refer to the experiences and lessons learned from the previous tasks:
Before generating code, please learn from the following past experiences and lessons:
{experiences}
You must use the experiences and lessons learned to generate the Python code.
You must apply them in code generation.
6 changes: 6 additions & 0 deletions taskweaver/ext_role/echo/echo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from taskweaver.logging import TelemetryLogger
from taskweaver.memory import Memory, Post
from taskweaver.memory.attachment import AttachmentType
from taskweaver.module.event_emitter import SessionEventEmitter
from taskweaver.module.tracing import Tracing
from taskweaver.role import Role
Expand Down Expand Up @@ -41,4 +42,9 @@ def reply(self, memory: Memory, **kwargs: ...) -> Post:
self.config.decorator + last_post.message + self.config.decorator,
)

post_proxy.update_attachment(
type=AttachmentType.signal,
message="exp_sub_path:sub_exp",
liqul marked this conversation as resolved.
Show resolved Hide resolved
)

return post_proxy.end()
4 changes: 4 additions & 0 deletions taskweaver/memory/attachment.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ class AttachmentType(Enum):
# board info
board = "board"

# signal
signal = "signal"
_signal_exp_sub_path = "_signal_exp_sub_path"


@dataclass
class Attachment:
Expand Down
63 changes: 41 additions & 22 deletions taskweaver/memory/experience.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,6 @@ class ExperienceConfig(ModuleConfig):
def _configure(self) -> None:
self._set_name("experience")

self.experience_dir = self._get_path(
"experience_dir",
os.path.join(self.src.app_base_path, "experience"),
)
self.default_exp_prompt_path = self._get_path(
"default_exp_prompt_path",
os.path.join(
Expand Down Expand Up @@ -83,6 +79,15 @@ def __init__(
"run `python -m experience_mgt --refresh` to refresh the experience."
)

self.experience_dir = None
self.sub_path = None

def set_experience_dir(self, experience_dir: str):
self.experience_dir = experience_dir

def set_sub_path(self, sub_path: str):
self.sub_path = sub_path
Jack-Q marked this conversation as resolved.
Show resolved Hide resolved

@staticmethod
def _preprocess_conversation_data(
conv_data: dict,
Expand All @@ -109,7 +114,9 @@ def summarize_experience(
exp_id: str,
prompt: Optional[str] = None,
):
raw_exp_file_path = os.path.join(self.config.experience_dir, f"raw_exp_{exp_id}.yaml")
exp_dir = self.get_experience_dir()

raw_exp_file_path = os.path.join(exp_dir, f"raw_exp_{exp_id}.yaml")
conversation = read_yaml(raw_exp_file_path)

conversation = self._preprocess_conversation_data(conversation)
Expand Down Expand Up @@ -145,10 +152,12 @@ def refresh(
self,
prompt: Optional[str] = None,
):
if not os.path.exists(self.config.experience_dir):
raise ValueError(f"Experience directory {self.config.experience_dir} does not exist.")
exp_dir = self.get_experience_dir()

if not os.path.exists(exp_dir):
raise ValueError(f"Experience directory {exp_dir} does not exist.")

exp_files = os.listdir(self.config.experience_dir)
exp_files = os.listdir(exp_dir)

raw_exp_ids = [
os.path.splitext(os.path.basename(exp_file))[0].split("_")[2]
Expand Down Expand Up @@ -176,10 +185,10 @@ def refresh(
for idx, exp_id in enumerate(exp_ids):
rebuild_flag = False
exp_file_name = f"exp_{exp_id}.yaml"
if exp_file_name not in os.listdir(self.config.experience_dir):
if exp_file_name not in os.listdir(exp_dir):
rebuild_flag = True
else:
exp_file_path = os.path.join(self.config.experience_dir, exp_file_name)
exp_file_path = os.path.join(exp_dir, exp_file_name)
experience = read_yaml(exp_file_path)
if (
experience["embedding_model"] != self.llm_api.embedding_service.config.embedding_model
Expand All @@ -194,13 +203,13 @@ def refresh(
experience_text=summarized_experience,
exp_id=exp_id,
raw_experience_path=os.path.join(
self.config.experience_dir,
exp_dir,
f"raw_exp_{exp_id}.yaml",
),
)
elif exp_id in handcrafted_exp_ids:
handcrafted_exp_file_path = os.path.join(
self.config.experience_dir,
exp_dir,
f"handcrafted_exp_{exp_id}.yaml",
)
experience_obj = Experience.from_dict(read_yaml(handcrafted_exp_file_path))
Expand All @@ -218,21 +227,21 @@ def refresh(
for i, exp in enumerate(to_be_embedded):
exp.embedding = exp_embeddings[i]
exp.embedding_model = self.llm_api.embedding_service.config.embedding_model
experience_file_path = os.path.join(self.config.experience_dir, f"exp_{exp.exp_id}.yaml")
experience_file_path = os.path.join(exp_dir, f"exp_{exp.exp_id}.yaml")
write_yaml(experience_file_path, exp.to_dict())

self.logger.info("Experience obj saved.")

@tracing_decorator
def load_experience(
self,
):
if not os.path.exists(self.config.experience_dir):
raise ValueError(f"Experience directory {self.config.experience_dir} does not exist.")
def load_experience(self):
exp_dir = self.get_experience_dir()

if not os.path.exists(exp_dir):
raise ValueError(f"Experience directory {exp_dir} does not exist.")

original_exp_files = [
exp_file
for exp_file in os.listdir(self.config.experience_dir)
for exp_file in os.listdir(exp_dir)
if exp_file.startswith("raw_exp_") or exp_file.startswith("handcrafted_exp_")
]
exp_ids = [os.path.splitext(os.path.basename(exp_file))[0].split("_")[2] for exp_file in original_exp_files]
Expand All @@ -245,8 +254,12 @@ def load_experience(
return

for exp_id in exp_ids:
exp_id_exists = exp_id in [exp.exp_id for exp in self.experience_list]
if exp_id_exists:
continue

exp_file = f"exp_{exp_id}.yaml"
exp_file_path = os.path.join(self.config.experience_dir, exp_file)
exp_file_path = os.path.join(exp_dir, exp_file)
assert os.path.exists(exp_file_path), (
f"Experience {exp_file} not found. " + self.exception_message_for_refresh
)
Expand Down Expand Up @@ -293,12 +306,18 @@ def retrieve_experience(self, user_query: str) -> List[Tuple[Experience, float]]
return selected_experiences

def _delete_exp_file(self, exp_file_name: str):
if exp_file_name in os.listdir(self.config.experience_dir):
os.remove(os.path.join(self.config.experience_dir, exp_file_name))
exp_dir = self.get_experience_dir()

if exp_file_name in os.listdir(exp_dir):
os.remove(os.path.join(exp_dir, exp_file_name))
self.logger.info(f"Experience {exp_file_name} deleted.")
else:
self.logger.info(f"Experience {exp_file_name} not found.")

def get_experience_dir(self):
assert self.experience_dir is not None, "Experience directory is not set. Call set_experience_dir() first."
return os.path.join(self.experience_dir, self.sub_path) if self.sub_path else self.experience_dir

def delete_experience(self, exp_id: str):
exp_file_name = f"exp_{exp_id}.yaml"
self._delete_exp_file(exp_file_name)
Expand Down
Loading