Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding the ability to learn multiple skills simultaneously #21

Merged
merged 6 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ agent = Agent(
# connect to a dataset
environment=BasicEnvironment(
ground_truth_dataset=ground_truth_dataset,
ground_truth_column="ground_truth"
ground_truth_columns={"sentiment_classification": "ground_truth"}
),

# define a skill
Expand All @@ -148,10 +148,10 @@ agent = Agent(
default_runtime='openai',

# NOTE! If you have access to GPT-4, you can uncomment the lines bellow for better results
# default_teacher_runtime='openai-gpt4',
# teacher_runtimes = {
# 'openai-gpt4': OpenAIRuntime(model='gpt-4')
# }
# default_teacher_runtime='openai-gpt4',
# teacher_runtimes = {
# 'openai-gpt4': OpenAIRuntime(model='gpt-4')
# }
)

print(agent)
Expand All @@ -160,9 +160,9 @@ print(agent.skills)
agent.learn(learning_iterations=3, accuracy_threshold=0.95)

print('\n=> Run tests ...')
run = agent.apply_skills(predict_dataset)
predictions = agent.run(predict_dataset)
print('\n => Test results:')
print(run)
print(predictions)
```

### 👉 Available skills
Expand Down
139 changes: 64 additions & 75 deletions adala/agents/base.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
from pydantic import BaseModel, Field, SkipValidation, field_validator, model_validator
from abc import ABC, abstractmethod
from typing import Any, Optional, List, Dict, Union
from adala.environments.base import Environment, BasicEnvironment
from typing import Any, Optional, List, Dict, Union, Tuple
from rich import print

from adala.environments.base import Environment, BasicEnvironment, GroundTruthSignal
from adala.datasets import Dataset, DataFrameDataset
from adala.runtimes.base import Runtime, LLMRuntime, LLMRuntimeType, LLMRuntimeModelType
from adala.runtimes.openai import OpenAIRuntime
from adala.memories.base import ShortTermMemory, LongTermMemory
from adala.memories.base import Memory
from adala.skills.base import BaseSkill
from adala.skills.skillset import SkillSet, LinearSkillSet
from adala.utils.logs import print_dataframe, print_text, print_error
from adala.utils.internal_data import InternalDataFrame
from adala.utils.internal_data import InternalDataFrame, InternalDataFrameConcat


class Agent(BaseModel, ABC):
Expand All @@ -23,12 +25,14 @@ class Agent(BaseModel, ABC):
memory (LongTermMemory, optional): The agent's long-term memory. Defaults to None.
runtimes (Dict[str, Runtime], optional): The runtimes available to the agent. Defaults to predefined runtimes.
default_runtime (str): The default runtime used by the agent. Defaults to 'openai'.
teacher_runtimes (Dict[str, Runtime], optional): The runtimes available to the agent's teacher. Defaults to predefined runtimes.
default_teacher_runtime (str): The default runtime used by the agent's teacher. Defaults to 'openai-gpt3'.
"""

environment: Union[InternalDataFrame, Dataset, Environment] = Field(default_factory=DataFrameDataset)
skills: Union[SkillSet, BaseSkill, List[BaseSkill], Dict[str, BaseSkill]]
skills: SkillSet

memory: LongTermMemory = Field(default=None)
memory: Memory = Field(default=None)
runtimes: Optional[Dict[str, Runtime]] = Field(
default_factory=lambda: {
'openai': OpenAIRuntime(model='gpt-3.5-turbo-instruct'),
Expand Down Expand Up @@ -90,7 +94,7 @@ def environment_validator(cls, v):
v = BasicEnvironment(dataset=v)
return v

@field_validator('skills')
@field_validator('skills', mode='before')
def skills_validator(cls, v):
"""
Validates and possibly transforms the skills attribute.
Expand All @@ -103,14 +107,11 @@ def skills_validator(cls, v):
"""

if isinstance(v, SkillSet):
pass
return v
elif isinstance(v, BaseSkill):
v = LinearSkillSet(skills={'skill_0': v})
elif isinstance(v, list):
v = LinearSkillSet(skills={f'skill_{i}': skill for i, skill in enumerate(v)})
elif isinstance(v, dict):
v = LinearSkillSet(skills=v)
return v
return LinearSkillSet(skills={v.name: v})
else:
return LinearSkillSet(skills=v)

@model_validator(mode='after')
def verify_input_parameters(self):
Expand Down Expand Up @@ -169,116 +170,104 @@ def get_teacher_runtime(self, runtime: Optional[str] = None) -> Runtime:
raise ValueError(f'Teacher Runtime "{runtime}" not found.')
return self.teacher_runtimes[runtime]

def apply_skills(
self,
dataset: Union[Dataset, InternalDataFrame],
runtime: Optional[Union[str, Runtime]] = None,
experience: Optional[ShortTermMemory] = None,
) -> ShortTermMemory:
def run(self, dataset: Union[Dataset, InternalDataFrame], runtime: Optional[str] = None) -> InternalDataFrame:
"""
Applies the agent's skills to a given dataset using the specified runtime.
Runs the agent on the specified dataset.

Args:
dataset (Dataset): The dataset to apply skills on.
runtime (str, optional): The runtime to use. Defaults to None.
experience (ShortTermMemory, optional): The agent's short-term memory. Defaults to None.
dataset (Union[Dataset, InternalDataFrame]): The dataset to run the agent on.
runtime (str, optional): The name of the runtime to use. Defaults to None, use the default runtime.

Returns:
ShortTermMemory: The short-term memory resulting from the application of skills.
InternalDataFrame: The dataset with the agent's predictions.
"""
runtime = runtime or self.default_runtime
if isinstance(dataset, InternalDataFrame):
dataset = DataFrameDataset(df=dataset)
if isinstance(runtime, str):
runtime = self.get_runtime(runtime=runtime)
return self.skills.apply(dataset=dataset, runtime=runtime, experience=experience)
runtime = self.get_runtime(runtime=runtime)
predictions = self.skills.apply(dataset, runtime=runtime)
return predictions

def learn(
self,
learning_iterations: int = 3,
accuracy_threshold: float = 0.9,
update_skills: bool = True,
update_memory: bool = True,
request_environment_feedback: bool = True,
experience: Optional[ShortTermMemory] = None,
runtime: Optional[str] = None,
) -> ShortTermMemory:
teacher_runtime: Optional[str] = None,
) -> GroundTruthSignal:
"""
Enables the agent to learn and improve its skills based on interactions with its environment.

Args:
learning_iterations (int, optional): The number of iterations for learning. Defaults to 3.
accuracy_threshold (float, optional): The desired accuracy threshold to reach. Defaults to 0.9.
update_skills (bool, optional): Flag to determine if skills should be updated after learning. Defaults to True.
update_memory (bool, optional): Flag to determine if memory should be updated after learning. Defaults to True.
request_environment_feedback (bool, optional): Flag to determine if feedback should be requested from the environment. Defaults to True.
experience (ShortTermMemory, optional): Initial experience for the learning process. Defaults to None.
runtime (str, optional): The runtime to be used for the learning process. Defaults to None.

teacher_runtime (str, optional): The teacher runtime to be used for the learning process. Defaults to None.
Returns:
ShortTermMemory: The short-term memory after the learning process.
GroundTruthSignal: The ground truth signal.
"""

runtime = self.get_runtime(runtime=runtime)
# TODO: support teacher runtime input, not default
teacher_runtime = self.get_teacher_runtime(runtime=self.default_teacher_runtime)
teacher_runtime = self.get_teacher_runtime(runtime=teacher_runtime)

skills = self.skills.model_copy(deep=True)
dataset = self.environment.as_dataset()

# Apply agent skills to dataset and get experience with predictions
experience = self.apply_skills(dataset=dataset, runtime=runtime, experience=experience)

# Agent select one skill to improve
learned_skill = skills.select_skill_to_improve(experience)
predictions = self.skills.apply(dataset, runtime=runtime)

# Request feedback from environment is necessary
if request_environment_feedback:
self.environment.request_feedback(learned_skill, experience)
ground_truth_signal = None

for iteration in range(learning_iterations):
print_text(f'\n\n=> Iteration #{iteration}: Comparing to ground truth, analyzing and improving ...')

# 1. EVALUATION PHASE: Compare predictions to ground truth
experience = self.environment.compare_to_ground_truth(learned_skill, experience)
# Request feedback from environment is necessary
if request_environment_feedback:
self.environment.request_feedback(self.skills, predictions)

# Compare predictions to ground truth -> get ground truth signal
ground_truth_signal = self.environment.compare_to_ground_truth(self.skills, predictions)
print_text(f'Comparing predictions to ground truth data ...')
print_dataframe(experience.evaluations)
print_dataframe(InternalDataFrameConcat([predictions, ground_truth_signal.match], axis=1))

# Use ground truth signal to find the skill to improve
accuracy = ground_truth_signal.get_accuracy()
train_skill = self.skills.select_skill_to_improve(accuracy, accuracy_threshold)
if not train_skill:
print_text(f'No skill to improve found. Stopping learning process.')
break
# select the worst performing skill
print_text(f'Accuracy = {accuracy[train_skill.name] * 100:0.2f}%', style='bold red')

skill_errors = ground_truth_signal.get_errors(train_skill.name)

# 2. ANALYSIS PHASE: Analyze evaluation experience, optionally use long term memory
print_text(f'Analyze evaluation experience ...')
experience = learned_skill.analyze(
experience=experience,
error_analysis = train_skill.analyze(
predictions=predictions,
errors=skill_errors,
student_runtime=runtime,
teacher_runtime=teacher_runtime,
memory=self.memory
)
print_text(f'Number of errors: {len(experience.errors)}')

print_text(f'Accuracy = {experience.accuracy*100:0.2f}%', style='bold red')
if experience.accuracy >= accuracy_threshold:
print_text(f'Accuracy threshold reached ({experience.accuracy} >= {accuracy_threshold})')
break
print_text(f'Error analysis for skill "{train_skill.name}":\n')
print_text(error_analysis, style='green')
if self.memory and update_memory:
self.memory.remember(error_analysis, self.skills)

# 3. IMPROVEMENT PHASE: Improve skills based on analysis
print_text(f"Improve \"{learned_skill.name}\" skill based on analysis ...")
experience = learned_skill.improve(
experience=experience,
print_text(f"Improve \"{train_skill.name}\" skill based on analysis ...")
train_skill.improve(
error_analysis=error_analysis,
runtime=teacher_runtime,
update_instructions=True
)
print_text(f'Updated instructions for skill "{learned_skill.name}":\n')
print_text(learned_skill.instructions, style='bold green')
print_text(f'Updated instructions for skill "{train_skill.name}":\n')
print_text(train_skill.instructions, style='bold green')

# 4. RE-APPLY PHASE: Re-apply skills to dataset
print_text(f"Re-apply {learned_skill.name} skill to dataset ...")
experience = learned_skill.apply(dataset, runtime, experience=experience)

# Update skills and memory based on experience
if update_skills:
self.skills = skills

if self.memory and update_memory:
self.memory.remember(experience, self.skills)
print_text(f"Re-apply {train_skill.name} skill to dataset ...")
self.skills[train_skill.name] = train_skill
predictions = self.skills.apply(predictions, runtime=runtime, improved_skill=train_skill.name)

print_text('Train is done!')
return experience
return ground_truth_signal
Loading