Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Liqun/refactor role #441

Merged
merged 9 commits into from
Nov 19, 2024
156 changes: 84 additions & 72 deletions taskweaver/role/role.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os.path
from dataclasses import dataclass
from datetime import timedelta
from typing import List, Optional, Set, Tuple, Union
from typing import List, Literal, Optional, Set, Tuple, Union

from injector import Module, inject, provider

Expand Down Expand Up @@ -153,101 +153,113 @@ def format_experience(
else ""
)

def prepare_loading(
self,
use_flag: bool,
dynamic_sub_path: bool,
base_path: str,
memory: Optional[Memory],
loaded_from_attr: str,
item_type: Literal["experience", "example"],
) -> Optional[str]:
"""Prepare for loading by checking configurations and memory, and return load_from path if applicable."""
if not use_flag:
setattr(self, f"{item_type}s", [])
return None

if not os.path.exists(base_path):
raise FileNotFoundError(
f"The default {item_type} base path {base_path} does not exist."
f"The original {item_type} base paths have been changed to `{item_type}s` folder."
f"Please migrate the {item_type}s to the new base path.",
)

sub_path = ""
if dynamic_sub_path:
assert memory is not None, f"Memory should be provided when dynamic_{item_type}_sub_path is True"
sub_paths = memory.get_shared_memory_entries(entry_type=f"{item_type}_sub_path")
if sub_paths:
self.tracing.set_span_attribute(f"{item_type}_sub_path", str(sub_paths))
# todo: handle multiple sub paths
sub_path = sub_paths[0].content
else:
self.logger.info(f"No {item_type} sub path found in memory.")
setattr(self, f"{item_type}s", [])
return None

load_from = os.path.join(base_path, sub_path)
if getattr(self, loaded_from_attr) is not None and getattr(self, loaded_from_attr) == load_from:
self.logger.info(f"{item_type.capitalize()} already loaded from {load_from}.")
return None

setattr(self, loaded_from_attr, load_from)
return sub_path

def role_load_experience(
self,
query: str,
memory: Optional[Memory] = None,
) -> None:
if not self.config.use_experience:
self.experiences = []
sub_path = self.prepare_loading(
self.config.use_experience,
self.config.dynamic_experience_sub_path,
self.config.experience_dir,
memory,
"experience_loaded_from",
"experience",
)
if sub_path is None:
return

if self.experience_generator is None:
raise ValueError(
"Experience generator is not initialized. Each role instance should have its own generator.",
)

experience_sub_path = ""
if self.config.dynamic_experience_sub_path:
assert memory is not None, "Memory should be provided when dynamic_experience_sub_path is True"
experience_sub_paths = memory.get_shared_memory_entries(entry_type="experience_sub_path")
if experience_sub_paths:
self.tracing.set_span_attribute("experience_sub_path", str(experience_sub_paths))
# todo: handle multiple experience sub paths
experience_sub_path = experience_sub_paths[0].content
else:
self.logger.info("No experience sub path found in memory.")
self.experiences = []
return

load_from = os.path.join(self.config.experience_dir, experience_sub_path)
if self.experience_loaded_from is None or self.experience_loaded_from != load_from:
self.experience_loaded_from = load_from
self.experience_generator.set_experience_dir(self.config.experience_dir)
self.experience_generator.set_sub_path(experience_sub_path)
self.experience_generator.refresh()
self.experience_generator.load_experience()
self.logger.info(
"Experience loaded successfully for {}, there are {} experiences with filter [{}]".format(
self.alias,
len(self.experience_generator.experience_list),
experience_sub_path,
),
)
else:
self.logger.info(f"Experience already loaded from {load_from}.")
self.experience_generator.set_experience_dir(self.config.experience_dir)
self.experience_generator.set_sub_path(sub_path)
self.experience_generator.refresh()
self.experience_generator.load_experience()
self.logger.info(
"Experience loaded successfully for {}, there are {} experiences with filter [{}]".format(
self.alias,
len(self.experience_generator.experience_list),
sub_path,
),
)

experiences = self.experience_generator.retrieve_experience(query)
self.logger.info(f"Retrieved {len(experiences)} experiences for query [{query}]")
self.experiences = [exp for exp, _ in experiences]

# todo: `role_load_example` is similar to `role_load_experience`, consider refactoring
def role_load_example(
self,
role_set: Set[str],
memory: Optional[Memory] = None,
) -> None:
if not self.config.use_example:
self.examples = []
sub_path = self.prepare_loading(
self.config.use_example,
self.config.dynamic_example_sub_path,
self.config.example_base_path,
memory,
"example_loaded_from",
"example",
)
if sub_path is None:
return

if not os.path.exists(self.config.example_base_path):
raise FileNotFoundError(
f"The default example base path {self.config.example_base_path} does not exist."
"The original example base paths have been changed to `examples` folder."
"Please migrate the examples to the new base path.",
)

example_sub_path = ""
if self.config.dynamic_example_sub_path:
assert memory is not None, "Memory should be provided when dynamic_example_sub_path is True"
example_sub_paths = memory.get_shared_memory_entries(entry_type="example_sub_path")
if example_sub_paths:
self.tracing.set_span_attribute("example_sub_path", str(example_sub_paths))
# todo: handle multiple sub paths
example_sub_path = example_sub_paths[0].content
else:
self.logger.info("No example sub path found in memory.")
self.examples = []
return

load_from = os.path.join(self.config.example_base_path, example_sub_path)
if self.example_loaded_from is None or self.example_loaded_from != load_from:
self.example_loaded_from = load_from
self.examples = load_examples(
folder=self.config.example_base_path,
sub_path=example_sub_path,
role_set=role_set,
)
self.logger.info(
"Example loaded successfully for {}, there are {} examples with filter [{}]".format(
self.alias,
len(self.examples),
example_sub_path,
),
)
else:
self.logger.info(f"Example already loaded from {load_from}.")
self.examples = load_examples(
folder=self.config.example_base_path,
sub_path=sub_path,
role_set=role_set,
)
self.logger.info(
"Example loaded successfully for {}, there are {} examples with filter [{}]".format(
self.alias,
len(self.examples),
sub_path,
),
)


class RoleModuleConfig(ModuleConfig):
Expand Down