Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
liqul committed Sep 6, 2024
1 parent 9d8acd9 commit 1d2b229
Show file tree
Hide file tree
Showing 7 changed files with 147 additions and 41 deletions.
37 changes: 29 additions & 8 deletions taskweaver/code_interpreter/code_interpreter/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def _configure(self) -> None:
self.auto_plugin_selection_topk = self._get_int("auto_plugin_selection_topk", 3)

self.use_experience = self._get_bool("use_experience", False)
self.dynamic_experience_filter = self._get_bool("dynamic_experience_filter", False)

self.llm_alias = self._get_str("llm_alias", default="", required=False)

Expand Down Expand Up @@ -104,17 +105,31 @@ def __init__(
logger.info("Plugin embeddings loaded")
self.selected_plugin_pool = SelectedPluginPool()

if self.config.use_experience:
self.experience_generator = experience_generator
self.experience_generator.refresh()
self.experience_generator.load_experience()
self.logger.info(
"Experience loaded successfully, "
"there are {} experiences".format(len(self.experience_generator.experience_list)),
)
self.experience_generator = experience_generator
self.exp_loaded = False
if self.config.dynamic_experience_filter:
self.exp_filter_str = None
else:
# use the experience folder
self.exp_filter_str = ""

self.logger.info("CodeGenerator initialized successfully")

def load_experience(self):
if self.exp_filter_str is None or self.exp_loaded:
return
self.experience_generator.set_sub_dir(self.exp_filter_str)
self.experience_generator.refresh()
self.experience_generator.load_experience()
self.logger.info(
"Experience loaded successfully, "
"there are {} experiences with filter [{}]".format(
len(self.experience_generator.experience_list),
self.exp_filter_str,
)
)
self.exp_loaded = True

def configure_verification(
self,
code_verification_on: bool,
Expand Down Expand Up @@ -363,6 +378,11 @@ def reply(
# obtain the query from the last round
query = rounds[-1].post_list[-1].message

exp_filter = rounds[-1].post_list[-1].get_attachment(AttachmentType.exp_filter)
if exp_filter:
self.exp_filter_str = exp_filter[0].content
self.tracing.set_span_attribute("exp_filter", self.exp_filter_str)

self.tracing.set_span_attribute("query", query)
self.tracing.set_span_attribute("enable_auto_plugin_selection", self.config.enable_auto_plugin_selection)
self.tracing.set_span_attribute("use_experience", self.config.use_experience)
Expand All @@ -371,6 +391,7 @@ def reply(
self.plugin_pool = self.select_plugins_for_prompt(query)

if self.config.use_experience:
self.load_experience()
selected_experiences = self.experience_generator.retrieve_experience(query)
else:
selected_experiences = None
Expand Down
6 changes: 6 additions & 0 deletions taskweaver/ext_role/echo/echo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from taskweaver.logging import TelemetryLogger
from taskweaver.memory import Memory, Post
from taskweaver.memory.attachment import AttachmentType
from taskweaver.module.event_emitter import SessionEventEmitter
from taskweaver.module.tracing import Tracing
from taskweaver.role import Role
Expand Down Expand Up @@ -41,4 +42,9 @@ def reply(self, memory: Memory, **kwargs: ...) -> Post:
self.config.decorator + last_post.message + self.config.decorator,
)

post_proxy.update_attachment(
type=AttachmentType.signal,
message="exp_filter:sub_exp",
)

return post_proxy.end()
4 changes: 4 additions & 0 deletions taskweaver/memory/attachment.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ class AttachmentType(Enum):
# board info
board = "board"

# signal
signal = "signal"
exp_filter = "exp_filter"


@dataclass
class Attachment:
Expand Down
58 changes: 39 additions & 19 deletions taskweaver/memory/experience.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import json
import os
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Set

from injector import inject

Expand Down Expand Up @@ -83,6 +84,11 @@ def __init__(
"run `python -m experience_mgt --refresh` to refresh the experience."
)

self.sub_dir = None

def set_sub_dir(self, sub_dir: str):
self.sub_dir = sub_dir

@staticmethod
def _preprocess_conversation_data(
conv_data: dict,
Expand All @@ -109,7 +115,9 @@ def summarize_experience(
exp_id: str,
prompt: Optional[str] = None,
):
raw_exp_file_path = os.path.join(self.config.experience_dir, f"raw_exp_{exp_id}.yaml")
exp_dir = self.get_experience_dir()

raw_exp_file_path = os.path.join(exp_dir, f"raw_exp_{exp_id}.yaml")
conversation = read_yaml(raw_exp_file_path)

conversation = self._preprocess_conversation_data(conversation)
Expand Down Expand Up @@ -145,10 +153,12 @@ def refresh(
self,
prompt: Optional[str] = None,
):
if not os.path.exists(self.config.experience_dir):
raise ValueError(f"Experience directory {self.config.experience_dir} does not exist.")
exp_dir = self.get_experience_dir()

exp_files = os.listdir(self.config.experience_dir)
if not os.path.exists(exp_dir):
raise ValueError(f"Experience directory {exp_dir} does not exist.")

exp_files = os.listdir(exp_dir)

raw_exp_ids = [
os.path.splitext(os.path.basename(exp_file))[0].split("_")[2]
Expand Down Expand Up @@ -176,10 +186,10 @@ def refresh(
for idx, exp_id in enumerate(exp_ids):
rebuild_flag = False
exp_file_name = f"exp_{exp_id}.yaml"
if exp_file_name not in os.listdir(self.config.experience_dir):
if exp_file_name not in os.listdir(exp_dir):
rebuild_flag = True
else:
exp_file_path = os.path.join(self.config.experience_dir, exp_file_name)
exp_file_path = os.path.join(exp_dir, exp_file_name)
experience = read_yaml(exp_file_path)
if (
experience["embedding_model"] != self.llm_api.embedding_service.config.embedding_model
Expand All @@ -194,13 +204,13 @@ def refresh(
experience_text=summarized_experience,
exp_id=exp_id,
raw_experience_path=os.path.join(
self.config.experience_dir,
exp_dir,
f"raw_exp_{exp_id}.yaml",
),
)
elif exp_id in handcrafted_exp_ids:
handcrafted_exp_file_path = os.path.join(
self.config.experience_dir,
exp_dir,
f"handcrafted_exp_{exp_id}.yaml",
)
experience_obj = Experience.from_dict(read_yaml(handcrafted_exp_file_path))
Expand All @@ -218,21 +228,21 @@ def refresh(
for i, exp in enumerate(to_be_embedded):
exp.embedding = exp_embeddings[i]
exp.embedding_model = self.llm_api.embedding_service.config.embedding_model
experience_file_path = os.path.join(self.config.experience_dir, f"exp_{exp.exp_id}.yaml")
experience_file_path = os.path.join(exp_dir, f"exp_{exp.exp_id}.yaml")
write_yaml(experience_file_path, exp.to_dict())

self.logger.info("Experience obj saved.")

@tracing_decorator
def load_experience(
self,
):
if not os.path.exists(self.config.experience_dir):
raise ValueError(f"Experience directory {self.config.experience_dir} does not exist.")
def load_experience(self):
exp_dir = self.get_experience_dir()

if not os.path.exists(exp_dir):
raise ValueError(f"Experience directory {exp_dir} does not exist.")

original_exp_files = [
exp_file
for exp_file in os.listdir(self.config.experience_dir)
for exp_file in os.listdir(exp_dir)
if exp_file.startswith("raw_exp_") or exp_file.startswith("handcrafted_exp_")
]
exp_ids = [os.path.splitext(os.path.basename(exp_file))[0].split("_")[2] for exp_file in original_exp_files]
Expand All @@ -245,8 +255,12 @@ def load_experience(
return

for exp_id in exp_ids:
exp_id_exists = exp_id in [exp.exp_id for exp in self.experience_list]
if exp_id_exists:
continue

exp_file = f"exp_{exp_id}.yaml"
exp_file_path = os.path.join(self.config.experience_dir, exp_file)
exp_file_path = os.path.join(exp_dir, exp_file)
assert os.path.exists(exp_file_path), (
f"Experience {exp_file} not found. " + self.exception_message_for_refresh
)
Expand Down Expand Up @@ -293,12 +307,18 @@ def retrieve_experience(self, user_query: str) -> List[Tuple[Experience, float]]
return selected_experiences

def _delete_exp_file(self, exp_file_name: str):
if exp_file_name in os.listdir(self.config.experience_dir):
os.remove(os.path.join(self.config.experience_dir, exp_file_name))
exp_dir = self.get_experience_dir()

if exp_file_name in os.listdir(exp_dir):
os.remove(os.path.join(exp_dir, exp_file_name))
self.logger.info(f"Experience {exp_file_name} deleted.")
else:
self.logger.info(f"Experience {exp_file_name} not found.")

def get_experience_dir(self):
return os.path.join(self.config.experience_dir, self.sub_dir) \
if self.sub_dir else self.config.experience_dir

def delete_experience(self, exp_id: str):
exp_file_name = f"exp_{exp_id}.yaml"
self._delete_exp_file(exp_file_name)
Expand Down
41 changes: 32 additions & 9 deletions taskweaver/planner/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def _configure(self) -> None:
)

self.use_experience = self._get_bool("use_experience", False)
self.dynamic_experience_filter = self._get_bool("dynamic_experience_filter", False)

self.llm_alias = self._get_str("llm_alias", default="", required=False)

Expand Down Expand Up @@ -100,18 +101,32 @@ def __init__(
self.round_compressor = round_compressor
self.compression_prompt_template = read_yaml(self.config.compression_prompt_path)["content"]

if self.config.use_experience:
assert experience_generator is not None, "Experience generator is required when use_experience is True"
self.experience_generator = experience_generator
self.experience_generator.refresh()
self.experience_generator.load_experience()
self.logger.info(
"Experience loaded successfully, "
"there are {} experiences".format(len(self.experience_generator.experience_list)),
)
self.experience_generator = experience_generator
self.exp_loaded = False
if self.config.dynamic_experience_filter:
self.exp_filter_str = None
else:
# use the experience folder
self.exp_filter_str = ""

self.logger.info("Planner initialized successfully")

def load_experience(self):
if self.exp_filter_str is None or self.exp_loaded:
return
self.experience_generator.set_sub_dir(self.exp_filter_str)
self.experience_generator.refresh()
self.experience_generator.load_experience()
self.logger.info(
"Experience loaded successfully, "
"there are {} experiences with filter [{}]".format(
len(self.experience_generator.experience_list),
self.exp_filter_str,
)
)

self.exp_loaded = True

def compose_sys_prompt(self, context: str):
worker_description = ""
for alias, role in self.workers.items():
Expand Down Expand Up @@ -274,10 +289,18 @@ def reply(
assert len(rounds) != 0, "No chat rounds found for planner"

user_query = rounds[-1].user_query

exp_filter = rounds[-1].post_list[-1].get_attachment(AttachmentType.exp_filter)
if exp_filter:
self.exp_filter_str = exp_filter[0]
self.tracing.set_span_attribute("exp_filter", self.exp_filter_str)

self.tracing.set_span_attribute("user_query", user_query)
self.tracing.set_span_attribute("use_experience", self.config.use_experience)
self.tracing.set_span_attribute("exp_filter", self.exp_filter_str)

if self.config.use_experience:
self.load_experience()
selected_experiences = self.experience_generator.retrieve_experience(user_query)
else:
selected_experiences = None
Expand Down
40 changes: 36 additions & 4 deletions taskweaver/session/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
from typing import Any, Dict, List, Literal, Optional

from injector import Injector, inject
from scipy.stats import logser

from taskweaver.config.module_config import ModuleConfig
from taskweaver.logging import TelemetryLogger
from taskweaver.memory import Memory, Post, Round
from taskweaver.memory import Memory, Post, Round, Attachment
from taskweaver.memory.attachment import AttachmentType
from taskweaver.module.event_emitter import SessionEventEmitter, SessionEventHandler
from taskweaver.module.tracing import Tracing, tracing_decorator, tracing_decorator_non_class
Expand Down Expand Up @@ -88,6 +89,7 @@ def __init__(
self.memory = Memory(session_id=self.session_id)

self.session_var: Dict[str, str] = {}
self.session_signal: Dict[AttachmentType, str] = {}

self.event_emitter = self.session_injector.get(SessionEventEmitter)
self.session_injector.binder.bind(SessionEventEmitter, self.event_emitter)
Expand All @@ -100,11 +102,21 @@ def __init__(
if role_name not in role_registry.get_role_name_list():
raise ValueError(f"Unknown role {role_name}")
role_entry = self.role_registry.get(role_name)
role_instance = self.session_injector.create_object(role_entry.module, {"role_entry": role_entry})
role_instance = self.session_injector.create_object(
role_entry.module,
{
"role_entry": role_entry
}
)
self.worker_instances[role_instance.get_alias()] = role_instance

if "planner" in self.config.roles:
self.planner = self.session_injector.create_object(Planner, {"workers": self.worker_instances})
self.planner = self.session_injector.create_object(
Planner,
{
"workers": self.worker_instances
}
)
self.session_injector.binder.bind(Planner, self.planner)

self.max_internal_chat_round_num = self.config.max_internal_chat_round_num
Expand Down Expand Up @@ -163,6 +175,16 @@ def _send_text_message(

@tracing_decorator_non_class
def _send_message(recipient: str, post: Post) -> Post:
# add session signal to the post
if self.session_signal:
for signal_type in self.session_signal:
post.add_attachment(
Attachment.create(
type=signal_type,
content=self.session_signal[signal_type],
)
)

self.tracing.set_span_attribute("in.from", post.send_from)
self.tracing.set_span_attribute("in.recipient", recipient)
self.tracing.set_span_attribute("in.message", post.message)
Expand Down Expand Up @@ -191,7 +213,17 @@ def _send_message(recipient: str, post: Post) -> Post:

board_attachment = reply_post.get_attachment(AttachmentType.board)
if len(board_attachment) > 0:
chat_round.write_board(reply_post.send_from, reply_post.get_attachment(AttachmentType.board)[0])
chat_round.write_board(
reply_post.send_from,
reply_post.get_attachment(AttachmentType.board)[0]
)

signal_attachments = reply_post.get_attachment(AttachmentType.signal)
for signal in signal_attachments:
signal_type, signal_value = signal.split(":")
self.logger.info(f"Session signal: {signal_type}={signal_value}")
# signal_type must be in AttachmentType
self.session_signal[AttachmentType(signal_type)] = signal_value

return reply_post

Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/test_experience.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true"


@pytest.mark.skipif(True, reason="Test doesn't work in Github Actions.")
@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Test doesn't work in Github Actions.")
def test_experience_retrieval():
app_injector = Injector([LoggingModule])
app_config = AppConfigSource(
Expand Down

0 comments on commit 1d2b229

Please sign in to comment.