Skip to content

Commit

Permalink
run planner exp case success
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangxu0307 committed Dec 28, 2023
1 parent eb412db commit 05fe755
Show file tree
Hide file tree
Showing 6 changed files with 813 additions and 808 deletions.
38 changes: 21 additions & 17 deletions taskweaver/memory/experience.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import os
import warnings
from dataclasses import dataclass
from typing import List, Literal, Optional, Tuple

Expand Down Expand Up @@ -35,8 +36,8 @@ class ExperienceConfig(ModuleConfig):
def _configure(self) -> None:
self._set_name("experience")

self.session_history_dir = self._get_path(
"session_history_dir",
self.experience_dir = self._get_path(
"experience_dir",
os.path.join(self.src.app_base_path, "experience"),
)
self.default_exp_prompt_path = self._get_path(
Expand All @@ -62,8 +63,7 @@ def __init__(
self.llm_api = llm_api
self.logger = logger

with open(self.config.default_exp_prompt_path, "r") as f:
self.default_prompt_template = f.read()
self.default_prompt_template = read_yaml(self.config.default_exp_prompt_path)["content"]

self.experience_list: List[Experience] = []

Expand All @@ -81,6 +81,8 @@ def remove_id_fields(d):
remove_id_fields(item)

def select_role(conv_data, target_role):
if target_role == "All":
return
for round_data in conv_data:
for idx, post in enumerate(round_data["post_list"]):
if post["send_from"] != target_role and post["send_to"] != target_role:
Expand All @@ -96,9 +98,9 @@ def summarize_experience(
self,
session_id: str,
prompt: Optional[str] = None,
target_role: Literal["Planner", "CodeInterpreter"] = "Planner",
target_role: Literal["Planner", "CodeInterpreter", "All"] = "Planner",
):
raw_exp_file_path = os.path.join(self.config.session_history_dir, f"raw_exp_{session_id}.yaml")
raw_exp_file_path = os.path.join(self.config.experience_dir, f"raw_exp_{session_id}.yaml")
conversation = read_yaml(raw_exp_file_path)

conversation = self._preprocess_conversation_data(conversation, target_role)
Expand All @@ -115,19 +117,20 @@ def summarize_experience(
def summarize_experience_in_batch(
self,
prompt: Optional[str] = None,
target_role: Literal["Planner", "CodeInterpreter"] = "Planner",
target_role: Literal["Planner", "CodeInterpreter", "All"] = "Planner",
):
exp_files = os.listdir(self.config.session_history_dir)
exp_files = os.listdir(self.config.experience_dir)
session_ids = [exp_file.split("_")[2].split(".")[0] for exp_file in exp_files if exp_file.startswith("raw_exp")]

if len(session_ids) == 0:
raise ValueError("No experience found.")
warnings.warn("No experience found. Please type SAVE AS EXP in the chat window to save experience.")
return

for session_id in session_ids:
exp_file_name = f"exp_{session_id}.yaml"
# if the experience file already exists, load it
if not self.config.refresh_experience and exp_file_name in os.listdir(self.config.session_history_dir):
exp_file_path = os.path.join(self.config.session_history_dir, exp_file_name)
if not self.config.refresh_experience and exp_file_name in os.listdir(self.config.experience_dir):
exp_file_path = os.path.join(self.config.experience_dir, exp_file_name)
experience = read_yaml(exp_file_path)
experience_obj = Experience(**experience)
self.experience_list.append(experience_obj)
Expand All @@ -139,7 +142,7 @@ def summarize_experience_in_batch(
experience_text=summarized_experience,
session_id=session_id,
raw_experience_path=os.path.join(
self.config.session_history_dir,
self.config.experience_dir,
f"raw_exp_{session_id}.yaml",
),
)
Expand All @@ -153,7 +156,7 @@ def summarize_experience_in_batch(
self.logger.info("Experience embeddings created. Embeddings number: {}".format(len(exp_embeddings)))

for exp in self.experience_list:
experience_file_path = os.path.join(self.config.session_history_dir, f"exp_{exp.session_id}.yaml")
experience_file_path = os.path.join(self.config.experience_dir, f"exp_{exp.session_id}.yaml")
write_yaml(experience_file_path, exp.to_dict())
self.logger.info("Experience obj saved.")

Expand Down Expand Up @@ -185,14 +188,15 @@ def retrieve_experience(self, user_query: str) -> List[Tuple[Experience, float]]
)

selected_experiences = [(exp, sim) for exp, sim in experience_rank if sim >= self.config.retrieve_threshold]

self.logger.info(f"Retrieved {len(selected_experiences)} experiences.")
self.logger.info(f"Retrieved experiences: {[exp.session_id for exp, sim in selected_experiences]}")
return selected_experiences

def delete_experience(self, session_id: str):
exp_file_name = f"exp_{session_id}.yaml"
if exp_file_name in os.listdir(self.config.session_history_dir):
os.remove(os.path.join(self.config.session_history_dir, exp_file_name))
os.remove(os.path.join(self.config.session_history_dir, f"raw_exp_{session_id}.yaml"))
if exp_file_name in os.listdir(self.config.experience_dir):
os.remove(os.path.join(self.config.experience_dir, exp_file_name))
os.remove(os.path.join(self.config.experience_dir, f"raw_exp_{session_id}.yaml"))
self.logger.info(f"Experience {exp_file_name} deleted.")
else:
self.logger.info(f"Experience {exp_file_name} not found.")
17 changes: 9 additions & 8 deletions taskweaver/planner/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,13 @@ def __init__(
self.round_compressor = round_compressor
self.compression_prompt_template = read_yaml(self.config.compression_prompt_path)["content"]

self.experience_manager = experience_manager
self.experience_prompt_template = read_yaml(self.config.exp_prompt_path)["content"]
self.experience_manager.summarize_experience_in_batch(
prompt=self.experience_prompt_template,
target_role="Planner",
)
if self.config.use_experience:
self.experience_manager = experience_manager
self.experience_prompt_template = read_yaml(self.config.exp_prompt_path)["content"]
self.experience_manager.summarize_experience_in_batch(
prompt=self.experience_prompt_template,
target_role="All",
)

self.logger.info("Planner initialized successfully")

Expand Down Expand Up @@ -189,7 +190,7 @@ def compose_prompt(
rounds: List[Round],
selected_experiences: Optional[List[Experience]] = None,
) -> List[ChatMessageType]:
if selected_experiences is not None:
if selected_experiences is not None and len(selected_experiences) != 0:
self.experience_instruction = self.prompt_data["experience_instruction"].format(
experiences="\n===================".join([exp.experience_text for exp, sim in selected_experiences]),
)
Expand Down Expand Up @@ -230,7 +231,7 @@ def reply(
assert len(rounds) != 0, "No chat rounds found for planner"

user_query = rounds[-1].user_query
if self.config.use_experience and self.experience_manager is not None:
if self.config.use_experience:
selected_experiences = self.experience_manager.retrieve_experience(user_query)
else:
selected_experiences = None
Expand Down
6 changes: 3 additions & 3 deletions taskweaver/planner/planner_exp_prompt.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ content: |-
# About Output
The output should be formatted as below:
- User Query: The user query/task/request for the given conversation.
- Best Practice: What's the best plan to fulfill the user query? What additional information or guidance does Users provide to help the Planner make the correct and clear plan? How does Planner instruct CodeInterpreter to execute the plan?
- Mistakes to Avoid: What's the mistakes should be avoided in the future planning.
- Best Practice: How does Planner make the correct and clear plan? What additional information or guidance does Users provide to help the Planner make the correct and clear plan? How does Planner instruct CodeInterpreter to execute the plan?
- Mistakes to Avoid: What's the mistakes should be avoided in the future planning? Is there any additional information or guidance that should be provided by Planner to make CodeInterpreter generate the correct code in one time?
- Critical Information: If there are critical information for planning, please include them in the output.
- DO NOT add the specific results of this conversation in the output because the experience and lessons learned should be generalized to other conversations.
- DO NOT add the results of this conversation in the output.
- Generate only one experience from the given conversation.
1 change: 1 addition & 0 deletions taskweaver/planner/planner_prompt.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,4 @@ experience_instruction: |-
# Experience And Lessons
Before starting planning, please refer to the following experiences and lessons learned from the previous tasks:
{experiences}
You need to borrow the experience and lessons learned from the previous tasks in your current plan.
Loading

0 comments on commit 05fe755

Please sign in to comment.