Add guidance, langchain, openai runtimes; multi-I/O skills
nik committed Nov 10, 2023
1 parent 8ac8cc9 commit 1b69500
Showing 25 changed files with 1,095 additions and 367 deletions.
README.md (29 changes: 11 additions & 18 deletions)
@@ -104,52 +104,45 @@ Click [here](./examples/quickstart.ipynb) to see an extended quickstart example.
 import pandas as pd
 
 from adala.agents import Agent
-from adala.datasets import DataFrameDataset
-from adala.environments import BasicEnvironment
+from adala.environments import StaticEnvironment
 from adala.skills import ClassificationSkill
-from adala.runtimes import OpenAIRuntime
+from adala.runtimes import OpenAIChatRuntime
 from rich import print
 
 # Train dataset
-ground_truth_df = pd.DataFrame([
+train_df = pd.DataFrame([
     ["It was the negative first impressions, and then it started working.", "Positive"],
     ["Not loud enough and doesn't turn on like it should.", "Negative"],
     ["I don't know what to say.", "Neutral"],
     ["Manager was rude, but the most important that mic shows very flat frequency response.", "Positive"],
     ["The phone doesn't seem to accept anything except CBR mp3s.", "Negative"],
     ["I tried it before, I bought this device for my son.", "Neutral"],
-], columns=["text", "ground_truth"])
+], columns=["text", "sentiment"])
 
 # Test dataset
-predict_df = pd.DataFrame([
+test_df = pd.DataFrame([
     "All three broke within two months of use.",
     "The device worked for a long time, can't say anything bad.",
     "Just a random line of text."
 ], columns=["text"])
 
-ground_truth_dataset = DataFrameDataset(df=ground_truth_df)
-predict_dataset = DataFrameDataset(df=predict_df)
 
 agent = Agent(
     # connect to a dataset
-    environment=BasicEnvironment(
-        ground_truth_dataset=ground_truth_dataset,
-        ground_truth_columns={"sentiment_classification": "ground_truth"}
-    ),
+    environment=StaticEnvironment(df=train_df),
 
     # define a skill
     skills=ClassificationSkill(
-        name='sentiment_classification',
+        name='sentiment',
         instructions="Label text as positive, negative or neutral.",
         labels=["Positive", "Negative", "Neutral"],
-        input_data_field='text'
+        input_template="Text: {text}",
+        output_template="Sentiment: {sentiment}"
    ),
 
     # define all the different runtimes your skills may use
     runtimes = {
-        # You can specify your OPENAI API KEY here via `OpenAIRuntime(..., api_key='your-api-key')`
-        'openai': OpenAIRuntime(model='gpt-3.5-turbo-instruct'),
-        'openai-gpt3': OpenAIRuntime(model='gpt-3.5-turbo')
+        'openai': OpenAIChatRuntime(model='gpt-3.5-turbo'),
     },
     default_runtime='openai',
 
@@ -166,7 +159,7 @@ print(agent.skills)
 agent.learn(learning_iterations=3, accuracy_threshold=0.95)
 
 print('\n=> Run tests ...')
-predictions = agent.run(predict_dataset)
+predictions = agent.run(test_df)
 print('\n => Test results:')
 print(predictions)
 ```
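The commit title also advertises multi-I/O skills. The new `input_template`/`output_template` pair suggests a skill can read several DataFrame columns at once and name its output field explicitly; a minimal sketch under that assumption (the `category` column and the two-field input template are illustrative, not taken from this diff):

```python
from adala.skills import ClassificationSkill

# Hypothetical multi-input skill: {category} and {text} are both filled from
# columns of the input DataFrame; {sentiment} is written back as a new column.
contextual_sentiment = ClassificationSkill(
    name='contextual_sentiment',
    instructions='Label the review as positive, negative or neutral, '
                 'taking the product category into account.',
    labels=['Positive', 'Negative', 'Neutral'],
    input_template='Category: {category}\nText: {text}',
    output_template='Sentiment: {sentiment}',
)
```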
adala/agents/base.py (116 changes: 60 additions & 56 deletions)
@@ -3,15 +3,19 @@
 from typing import Any, Optional, List, Dict, Union, Tuple
 from rich import print
 
-from adala.environments.base import Environment, BasicEnvironment, GroundTruthSignal
+from adala.environments.base import Environment, StaticEnvironment, GroundTruthSignal
 from adala.datasets import Dataset, DataFrameDataset
 from adala.runtimes.base import Runtime, LLMRuntime, LLMRuntimeType, LLMRuntimeModelType
-from adala.runtimes.openai import OpenAIRuntime
+# from adala.runtimes.openai import OpenAIRuntime
+from adala.runtimes._openai import OpenAIChatRuntime
+from adala.skills._base import Skill
 from adala.memories.base import Memory
 from adala.skills.base import BaseSkill
 from adala.skills.skillset import SkillSet, LinearSkillSet
 from adala.utils.logs import print_dataframe, print_text, print_error
 from adala.utils.internal_data import InternalDataFrame, InternalDataFrameConcat
+from adala.skills.collection.analyze_errors import AnalyzeLLMPromptErrorsExactMatch
+from adala.skills.collection.improve_llm import ImproveLLMInstructions
 
 
 class Agent(BaseModel, ABC):
@@ -35,7 +39,7 @@ class Agent(BaseModel, ABC):
     memory: Memory = Field(default=None)
     runtimes: Optional[Dict[str, Runtime]] = Field(
         default_factory=lambda: {
-            'openai': OpenAIRuntime(model='gpt-3.5-turbo-instruct'),
+            'openai': OpenAIChatRuntime(model='gpt-3.5-turbo'),
             # 'llama2': LLMRuntime(
             #     llm_runtime_type=LLMRuntimeModelType.Transformers,
             #     llm_params={
@@ -47,7 +51,7 @@
     )
     teacher_runtimes: Optional[Dict[str, Runtime]] = Field(
         default_factory=lambda: {
-            'openai-gpt3': OpenAIRuntime(model='gpt-3.5-turbo'),
+            'openai-gpt3': OpenAIChatRuntime(model='gpt-3.5-turbo'),
             # 'openai-gpt4': OpenAIRuntime(model='gpt-4')
         }
     )
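Taken together, these defaults move both the student and teacher paths onto the chat-completions runtime. A sketch of overriding them explicitly, mirroring the commented-out `gpt-4` entry above; assumptions: a raw DataFrame is accepted for `environment` (see the validator below), and no separate default-teacher key needs to be set:

```python
import pandas as pd

from adala.agents import Agent
from adala.runtimes._openai import OpenAIChatRuntime
from adala.skills import ClassificationSkill

train_df = pd.DataFrame(
    [['Great product!', 'Positive'], ['Broke on day one.', 'Negative']],
    columns=['text', 'sentiment'],
)

agent = Agent(
    environment=train_df,  # auto-wrapped into StaticEnvironment (validator below)
    skills=ClassificationSkill(
        name='sentiment',
        instructions='Label text as positive, negative or neutral.',
        labels=['Positive', 'Negative', 'Neutral'],
        input_template='Text: {text}',
        output_template='Sentiment: {sentiment}',
    ),
    # Student runtime: the cheap chat model applies the skill.
    runtimes={'openai': OpenAIChatRuntime(model='gpt-3.5-turbo')},
    default_runtime='openai',
    # Teacher runtime: a stronger model rewrites failing instructions.
    teacher_runtimes={'openai-gpt4': OpenAIChatRuntime(model='gpt-4')},
)
```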
@@ -89,9 +93,7 @@ def environment_validator(cls, v) -> Environment:
             Environment: The validated environment.
         """
         if isinstance(v, InternalDataFrame):
-            v = DataFrameDataset(df=v)
-        if isinstance(v, Dataset):
-            v = BasicEnvironment(dataset=v)
+            v = StaticEnvironment(df=v)
         return v
 
     @field_validator('skills', mode='before')
@@ -105,13 +107,12 @@ def skills_validator(cls, v) -> SkillSet:
         Returns:
             SkillSet: The validated set of skills.
         """
-
         if isinstance(v, SkillSet):
             return v
-        elif isinstance(v, BaseSkill):
-            return LinearSkillSet(skills={v.name: v})
+        elif isinstance(v, Skill):
+            return LinearSkillSet(skills=[v])
         else:
-            return LinearSkillSet(skills=v)
+            raise ValueError(f"skills must be of type SkillSet or Skill, not {type(v)}")
 
     @model_validator(mode='after')
     def verify_input_parameters(self):
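Note the validator is stricter than before: the old code silently wrapped whatever it received, while the new code accepts only a `Skill` or a `SkillSet`. Roughly, with `sentiment_skill` being any `Skill` instance such as the `ClassificationSkill` above (environment argument omitted for brevity):

```python
from adala.agents import Agent
from adala.skills.skillset import LinearSkillSet

Agent(skills=sentiment_skill)                           # wrapped into LinearSkillSet(skills=[sentiment_skill])
Agent(skills=LinearSkillSet(skills=[sentiment_skill]))  # passed through unchanged
# Agent(skills=[sentiment_skill])                       # a raw list now raises ValueError
```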
@@ -172,23 +173,23 @@ def get_teacher_runtime(self, runtime: Optional[str] = None) -> Runtime:
 
     def run(
             self,
-            dataset: Optional[Union[Dataset, InternalDataFrame]] = None,
+            input: InternalDataFrame = None,
             runtime: Optional[str] = None
     ) -> InternalDataFrame:
         """
         Runs the agent on the specified dataset.
 
         Args:
-            dataset (Union[Dataset, InternalDataFrame]): The dataset to run the agent on.
+            input (InternalDataFrame): The dataset to run the agent on.
             runtime (str, optional): The name of the runtime to use. Defaults to None, use the default runtime.
 
         Returns:
             InternalDataFrame: The dataset with the agent's predictions.
         """
-        if dataset is None:
-            dataset = self.environment.as_dataset()
+        if input is None:
+            input = self.environment.get_data_batch()
         runtime = self.get_runtime(runtime=runtime)
-        predictions = self.skills.apply(dataset, runtime=runtime)
+        predictions = self.skills.apply(input, runtime=runtime)
         return predictions
 
     def learn(
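The practical effect of the `run()` change: the `Dataset` indirection is gone. Callers pass a DataFrame directly, and calling `run()` with no input now draws a batch from the environment via `get_data_batch()`. Continuing the quickstart sketch (with `agent` and `test_df` as defined there):

```python
# Explicit input: predictions for an arbitrary DataFrame.
predictions = agent.run(test_df)

# No input: falls back to a batch pulled from the agent's own environment.
predictions = agent.run()
```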
@@ -197,8 +198,8 @@ def learn(
             accuracy_threshold: float = 0.9,
             update_memory: bool = True,
             request_environment_feedback: bool = True,
-            wait_for_environment_feedback: Optional[float] = None,
-            num_predictions_feedback: Optional[int] = None,
+            wait_for_feedback: Optional[float] = True,
+            num_feedbacks: Optional[int] = None,
             runtime: Optional[str] = None,
             teacher_runtime: Optional[str] = None,
     ) -> GroundTruthSignal:
@@ -210,8 +211,8 @@
             accuracy_threshold (float, optional): The desired accuracy threshold to reach. Defaults to 0.9.
             update_memory (bool, optional): Flag to determine if memory should be updated after learning. Defaults to True.
             request_environment_feedback (bool, optional): Flag to determine if feedback should be requested from the environment. Defaults to True.
-            wait_for_environment_feedback (float, optional): The timeout in seconds to wait for environment feedback. Defaults to None.
-            num_predictions_feedback (int, optional): The number of predictions to request feedback for. Defaults to None.
+            wait_for_feedback (float, optional): The timeout in seconds to wait for environment feedback. Defaults to None.
+            num_feedbacks (int, optional): The number of predictions to request feedback for. Defaults to None.
             runtime (str, optional): The runtime to be used for the learning process. Defaults to None.
             teacher_runtime (str, optional): The teacher runtime to be used for the learning process. Defaults to None.
 
         Returns:
@@ -221,10 +222,10 @@
         runtime = self.get_runtime(runtime=runtime)
         teacher_runtime = self.get_teacher_runtime(runtime=teacher_runtime)
 
-        dataset = self.environment.as_dataset()
+        data_batch = self.environment.get_data_batch()
 
         # Apply agent skills to dataset and get experience with predictions
-        predictions = self.skills.apply(dataset, runtime=runtime)
+        predictions = self.skills.apply(data_batch, runtime=runtime)
 
         ground_truth_signal = None
 
@@ -233,54 +234,57 @@
 
             # Request feedback from environment is necessary
             if request_environment_feedback:
-                if num_predictions_feedback is not None:
-                    # predictions_for_feedback = predictions.sample(num_predictions_feedback)
-                    predictions_for_feedback = predictions.head(num_predictions_feedback)
-                else:
-                    predictions_for_feedback = predictions
-                self.environment.request_feedback(self.skills, predictions_for_feedback)
+                self.environment.request_feedback(self.skills, predictions, num_feedbacks, wait_for_feedback)
 
             # Compare predictions to ground truth -> get ground truth signal
-            ground_truth_signal = self.environment.compare_to_ground_truth(
-                self.skills,
-                predictions,
-                wait=wait_for_environment_feedback
-            )
+            ground_truth_signal = self.environment.compare_to_ground_truth(self.skills, predictions)
 
             print_text(f'Comparing predictions to ground truth data ...')
             print_dataframe(InternalDataFrameConcat([predictions, ground_truth_signal.match], axis=1))
 
             # Use ground truth signal to find the skill to improve
             accuracy = ground_truth_signal.get_accuracy()
-            train_skill = self.skills.select_skill_to_improve(accuracy, accuracy_threshold)
-            if not train_skill:
+            train_skill_name, train_skill_output = '', ''
+            for skill_output, skill_name in self.skills.get_skill_outputs().items():
+                if accuracy[skill_output] < accuracy_threshold:
+                    train_skill_name, train_skill_output = skill_name, skill_output
+                    break
+
+            if not train_skill_name:
                 print_text(f'No skill to improve found. Stopping learning process.')
                 break
 
+            train_skill = self.skills[train_skill_name]
             # select the worst performing skill
-            print_text(f'Accuracy = {accuracy[train_skill.name] * 100:0.2f}%', style='bold red')
+            print_text(f'Output to improve: "{train_skill_output}" (Skill="{train_skill_name}")\n'
+                       f'Accuracy = {accuracy[train_skill_output] * 100:0.2f}%', style='bold red')
 
-            skill_errors = ground_truth_signal.get_errors(train_skill.name)
+            skill_errors = ground_truth_signal.get_errors(train_skill_output).rename('_ground_truth')
+            skill_errors = InternalDataFrameConcat((skill_errors, predictions), axis=1, join='inner')
+            print(f'Errors for skill "{train_skill_name}":')
+            print_dataframe(skill_errors)
 
             # 2. ANALYSIS PHASE: Analyze evaluation experience, optionally use long term memory
-            print_text(f'Analyze evaluation experience ...')
-            error_analysis = train_skill.analyze(
-                predictions=predictions,
-                errors=skill_errors,
-                student_runtime=runtime,
-                teacher_runtime=teacher_runtime,
-                memory=self.memory
-            )
-            print_text(f'Error analysis for skill "{train_skill.name}":\n')
-            print_text(error_analysis, style='green')
-            if self.memory and update_memory:
-                self.memory.remember(error_analysis, self.skills)
-
-            # 3. IMPROVEMENT PHASE: Improve skills based on analysis
-            print_text(f"Improve \"{train_skill.name}\" skill based on analysis ...")
-            train_skill.improve(
-                error_analysis=error_analysis,
-                runtime=teacher_runtime,
-            )
+            teacher = LinearSkillSet(skills=[
+                AnalyzeLLMPromptErrorsExactMatch(
+                    name='analyze',
+                    input_template=train_skill.input_template,
+                    output_template='{error_report}',
+                    initial_llm_instructions=train_skill.instructions,
+                    prediction_column=train_skill_output,
+                    ground_truth_column='_ground_truth',
+                    field_schema=train_skill.field_schema,
+                ),
+                ImproveLLMInstructions(
+                    name='improve',
+                    old_instructions=train_skill.instructions,
+                    field_schema=train_skill.field_schema,
+                )
+            ])
+
+            result = teacher.apply(skill_errors, runtime=teacher_runtime)
+            train_skill.instructions = result['new_instructions']
             print_text(f'Updated instructions for skill "{train_skill.name}":\n')
             print_text(train_skill.instructions, style='bold green')
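From the caller's side, the loop rewrite surfaces only as the renamed keyword arguments. A sketch of invoking the new learning loop (argument values are illustrative; `agent` as constructed above):

```python
# Renamed from wait_for_environment_feedback / num_predictions_feedback.
ground_truth_signal = agent.learn(
    learning_iterations=3,
    accuracy_threshold=0.95,
    num_feedbacks=10,        # request feedback for at most 10 predictions
    wait_for_feedback=True,  # wait until the environment has responded
)
```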
adala/environments/__init__.py (4 changes: 2 additions & 2 deletions)
@@ -1,3 +1,3 @@
-from .base import Environment, BasicEnvironment
+from .base import Environment, StaticEnvironment
 from .console import ConsoleEnvironment
-from .web import WebEnvironment
+from .web import WebStaticEnvironment
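For downstream code, the rename is mechanical:

```python
# Before this commit:
#   from adala.environments import BasicEnvironment, WebEnvironment
# After it:
from adala.environments import StaticEnvironment, WebStaticEnvironment
```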
(diffs for the remaining 22 changed files are not shown)
