Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
liqul committed Sep 9, 2024
1 parent 1d2b229 commit 9cf39cb
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 163 deletions.
96 changes: 0 additions & 96 deletions scripts/experience_mgt.py

This file was deleted.

58 changes: 32 additions & 26 deletions taskweaver/code_interpreter/code_interpreter/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,14 @@ def _configure(self) -> None:
self.auto_plugin_selection_topk = self._get_int("auto_plugin_selection_topk", 3)

self.use_experience = self._get_bool("use_experience", False)
self.dynamic_experience_filter = self._get_bool("dynamic_experience_filter", False)
self.experience_dir = self._get_path(
"experience_dir",
os.path.join(
self.src.app_base_path,
"experience",
),
)
self.dynamic_experience_sub_path = self._get_bool("dynamic_experience_sub_path", False)

self.llm_alias = self._get_str("llm_alias", default="", required=False)

Expand Down Expand Up @@ -106,29 +113,26 @@ def __init__(
self.selected_plugin_pool = SelectedPluginPool()

self.experience_generator = experience_generator
self.exp_loaded = False
if self.config.dynamic_experience_filter:
self.exp_filter_str = None
else:
# use the experience folder
self.exp_filter_str = ""
self.experience_loaded_from = None

self.logger.info("CodeGenerator initialized successfully")

def load_experience(self):
if self.exp_filter_str is None or self.exp_loaded:
return
self.experience_generator.set_sub_dir(self.exp_filter_str)
self.experience_generator.refresh()
self.experience_generator.load_experience()
self.logger.info(
"Experience loaded successfully, "
"there are {} experiences with filter [{}]".format(
len(self.experience_generator.experience_list),
self.exp_filter_str,
def load_experience(self, sub_path: str = ""):
load_from = os.path.join(self.config.experience_dir, sub_path)
if self.experience_loaded_from is None or self.experience_loaded_from != load_from:
self.experience_loaded_from = load_from
self.experience_generator.set_experience_dir(self.config.experience_dir)
self.experience_generator.set_sub_path(sub_path)
self.experience_generator.refresh()
self.experience_generator.load_experience()
self.logger.info(
"Experience loaded successfully, there are {} experiences with filter [{}]".format(
len(self.experience_generator.experience_list),
sub_path,
),
)
)
self.exp_loaded = True
else:
self.logger.info(f"Experience already loaded from {load_from}.")

def configure_verification(
self,
Expand Down Expand Up @@ -378,11 +382,6 @@ def reply(
# obtain the query from the last round
query = rounds[-1].post_list[-1].message

exp_filter = rounds[-1].post_list[-1].get_attachment(AttachmentType.exp_filter)
if exp_filter:
self.exp_filter_str = exp_filter[0].content
self.tracing.set_span_attribute("exp_filter", self.exp_filter_str)

self.tracing.set_span_attribute("query", query)
self.tracing.set_span_attribute("enable_auto_plugin_selection", self.config.enable_auto_plugin_selection)
self.tracing.set_span_attribute("use_experience", self.config.use_experience)
Expand All @@ -391,7 +390,14 @@ def reply(
self.plugin_pool = self.select_plugins_for_prompt(query)

if self.config.use_experience:
self.load_experience()
if not self.config.dynamic_experience_sub_path:
self.load_experience()
else:
exp_sub_path = rounds[-1].post_list[-1].get_attachment(AttachmentType.exp_sub_path)
if exp_sub_path:
self.load_experience(exp_sub_path[0])
self.tracing.set_span_attribute("exp_sub_path", exp_sub_path[0])

selected_experiences = self.experience_generator.retrieve_experience(query)
else:
selected_experiences = None
Expand Down
2 changes: 1 addition & 1 deletion taskweaver/ext_role/echo/echo.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def reply(self, memory: Memory, **kwargs: ...) -> Post:

post_proxy.update_attachment(
type=AttachmentType.signal,
message="exp_filter:sub_exp",
message="exp_sub_path:sub_exp",
)

return post_proxy.end()
2 changes: 1 addition & 1 deletion taskweaver/memory/attachment.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class AttachmentType(Enum):

# signal
signal = "signal"
exp_filter = "exp_filter"
exp_sub_path = "exp_sub_path"


@dataclass
Expand Down
21 changes: 10 additions & 11 deletions taskweaver/memory/experience.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import json
import os
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple, Set
from typing import Any, Dict, List, Optional, Tuple

from injector import inject

Expand Down Expand Up @@ -45,10 +44,6 @@ class ExperienceConfig(ModuleConfig):
def _configure(self) -> None:
self._set_name("experience")

self.experience_dir = self._get_path(
"experience_dir",
os.path.join(self.src.app_base_path, "experience"),
)
self.default_exp_prompt_path = self._get_path(
"default_exp_prompt_path",
os.path.join(
Expand Down Expand Up @@ -84,10 +79,14 @@ def __init__(
"run `python -m experience_mgt --refresh` to refresh the experience."
)

self.sub_dir = None
self.experience_dir = None
self.sub_path = None

def set_experience_dir(self, experience_dir: str):
self.experience_dir = experience_dir

def set_sub_dir(self, sub_dir: str):
self.sub_dir = sub_dir
def set_sub_path(self, sub_path: str):
self.sub_path = sub_path

@staticmethod
def _preprocess_conversation_data(
Expand Down Expand Up @@ -316,8 +315,8 @@ def _delete_exp_file(self, exp_file_name: str):
self.logger.info(f"Experience {exp_file_name} not found.")

def get_experience_dir(self):
return os.path.join(self.config.experience_dir, self.sub_dir) \
if self.sub_dir else self.config.experience_dir
assert self.experience_dir is not None, "Experience directory is not set. Call set_experience_dir() first."
return os.path.join(self.experience_dir, self.sub_path) if self.sub_path else self.experience_dir

def delete_experience(self, exp_id: str):
exp_file_name = f"exp_{exp_id}.yaml"
Expand Down
61 changes: 33 additions & 28 deletions taskweaver/planner/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,16 @@ def _configure(self) -> None:
),
)

# experience related
self.use_experience = self._get_bool("use_experience", False)
self.dynamic_experience_filter = self._get_bool("dynamic_experience_filter", False)
self.experience_dir = self._get_path(
"experience_dir",
os.path.join(
app_dir,
"experience",
),
)
self.dynamic_experience_sub_path = self._get_bool("dynamic_experience_sub_path", False)

self.llm_alias = self._get_str("llm_alias", default="", required=False)

Expand Down Expand Up @@ -102,30 +110,26 @@ def __init__(
self.compression_prompt_template = read_yaml(self.config.compression_prompt_path)["content"]

self.experience_generator = experience_generator
self.exp_loaded = False
if self.config.dynamic_experience_filter:
self.exp_filter_str = None
else:
# use the experience folder
self.exp_filter_str = ""
self.experience_loaded_from = None

self.logger.info("Planner initialized successfully")

def load_experience(self):
if self.exp_filter_str is None or self.exp_loaded:
return
self.experience_generator.set_sub_dir(self.exp_filter_str)
self.experience_generator.refresh()
self.experience_generator.load_experience()
self.logger.info(
"Experience loaded successfully, "
"there are {} experiences with filter [{}]".format(
len(self.experience_generator.experience_list),
self.exp_filter_str,
def load_experience(self, sub_path: str = ""):
load_from = os.path.join(self.config.experience_dir, sub_path)
if self.experience_loaded_from is None or self.experience_loaded_from != load_from:
self.experience_loaded_from = load_from
self.experience_generator.set_experience_dir(self.config.experience_dir)
self.experience_generator.set_sub_path(sub_path)
self.experience_generator.refresh()
self.experience_generator.load_experience()
self.logger.info(
"Experience loaded successfully, there are {} experiences with filter [{}]".format(
len(self.experience_generator.experience_list),
sub_path,
),
)
)

self.exp_loaded = True
else:
self.logger.info(f"Experience already loaded from {load_from}.")

def compose_sys_prompt(self, context: str):
worker_description = ""
Expand Down Expand Up @@ -290,17 +294,18 @@ def reply(

user_query = rounds[-1].user_query

exp_filter = rounds[-1].post_list[-1].get_attachment(AttachmentType.exp_filter)
if exp_filter:
self.exp_filter_str = exp_filter[0]
self.tracing.set_span_attribute("exp_filter", self.exp_filter_str)

self.tracing.set_span_attribute("user_query", user_query)
self.tracing.set_span_attribute("use_experience", self.config.use_experience)
self.tracing.set_span_attribute("exp_filter", self.exp_filter_str)

if self.config.use_experience:
self.load_experience()
if not self.config.dynamic_experience_sub_path:
self.load_experience()
else:
exp_sub_path = rounds[-1].post_list[-1].get_attachment(AttachmentType.exp_sub_path)
if exp_sub_path:
self.load_experience(exp_sub_path[0])
self.tracing.set_span_attribute("exp_sub_path", exp_sub_path[0])

selected_experiences = self.experience_generator.retrieve_experience(user_query)
else:
selected_experiences = None
Expand Down

0 comments on commit 9cf39cb

Please sign in to comment.