Skip to content

Commit

Permalink
Add class registry for API
Browse files Browse the repository at this point in the history
  • Loading branch information
nik committed Feb 19, 2024
1 parent bc912c4 commit e17a2a1
Show file tree
Hide file tree
Showing 13 changed files with 211 additions and 255 deletions.
19 changes: 12 additions & 7 deletions adala/agents/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
from rich import print
import yaml

from adala.environments.base import Environment, AsyncEnvironment, StaticEnvironment, EnvironmentFeedback, create_environment
from adala.runtimes.base import Runtime, AsyncRuntime, create_runtime
from adala.environments.base import Environment, AsyncEnvironment, EnvironmentFeedback
from adala.environments.static_env import StaticEnvironment
from adala.runtimes.base import Runtime, AsyncRuntime
from adala.runtimes._openai import OpenAIChatRuntime
from adala.runtimes import GuidanceRuntime
from adala.skills._base import Skill, AnalysisSkill, TransformSkill, SynthesisSkill
from adala.skills._base import Skill
from adala.memories.base import Memory
from adala.skills.skillset import SkillSet, LinearSkillSet
from adala.utils.logs import (
Expand All @@ -18,7 +19,7 @@
highlight_differences,
is_running_in_jupyter,
)
from adala.utils.internal_data import InternalDataFrame, InternalDataFrameConcat
from adala.utils.internal_data import InternalDataFrame


class Agent(BaseModel, ABC):
Expand Down Expand Up @@ -102,7 +103,7 @@ def environment_validator(cls, v) -> Environment:
if isinstance(v, InternalDataFrame):
v = StaticEnvironment(df=v)
elif isinstance(v, dict) and "type" in v:
v = create_environment(v.pop("type"), **v)
v = Environment.create_from_registry(v.pop("type"), **v)
return v

@field_validator("skills", mode="before")
Expand All @@ -126,8 +127,12 @@ def runtimes_validator(cls, v) -> Dict[str, Union[Runtime, AsyncRuntime]]:
"""
out = {}
for runtime_name, runtime_value in v.items():
if isinstance(runtime_value, dict) and "type" in runtime_value:
runtime_value = create_runtime(runtime_value.pop('type'), **runtime_value)
if isinstance(runtime_value, dict):
if "type" not in runtime_value:
raise ValueError(
f"Runtime {runtime_name} must have a 'type' field to specify the runtime type."
)
runtime_value = Runtime.create_from_registry(runtime_value.pop('type'), **runtime_value)
out[runtime_name] = runtime_value
return out

Expand Down
4 changes: 3 additions & 1 deletion adala/environments/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from .base import Environment, AsyncEnvironment, StaticEnvironment, EnvironmentFeedback
from .base import Environment, AsyncEnvironment, EnvironmentFeedback
from .static_env import StaticEnvironment
from .console import ConsoleEnvironment
from .web import WebStaticEnvironment
from .code_env import SimpleCodeValidationEnvironment
from .kafka import AsyncKafkaEnvironment, FileStreamAsyncKafkaEnvironment
178 changes: 4 additions & 174 deletions adala/environments/base.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
import pandas as pd
import numpy as np
from pydantic import BaseModel, Field
from pydantic import BaseModel
from abc import ABC, abstractmethod
from typing import Optional, Dict, Union, Callable, Type
from typing import Optional, ClassVar
from adala.utils.internal_data import (
InternalDataFrame,
InternalSeries,
InternalDataFrameConcat,
)
from adala.utils.matching import fuzzy_match
from adala.skills.skillset import SkillSet
from adala.utils.registry import BaseModelInRegistry


class EnvironmentFeedback(BaseModel):
Expand Down Expand Up @@ -56,7 +53,7 @@ def __rich__(self):
return text


class Environment(BaseModel, ABC):
class Environment(BaseModelInRegistry):
"""
An abstract base class that defines the structure and required methods for an environment
in which machine learning models operate and are evaluated against ground truth data.
Expand Down Expand Up @@ -196,170 +193,3 @@ async def restore(self):

class Config:
arbitrary_types_allowed = True


class StaticEnvironment(Environment):
"""
Static environment that initializes everything from the dataframe
and doesn't not require requesting feedback to create the ground truth.
Attributes
df (InternalDataFrame): The dataframe containing the ground truth.
ground_truth_columns ([Dict[str, str]]):
A dictionary mapping skill outputs to ground truth columns.
If not specified, the skill outputs are assumed to be the ground truth columns.
If a skill output is not in the dictionary, it is assumed to have no ground truth signal - NaNs are returned in the feedback.
matching_function (str, optional): The matching function to match ground truth strings with prediction strings.
Defaults to 'fuzzy'.
matching_threshold (float, optional): The matching threshold for the matching function.
Examples:
>>> df = pd.DataFrame({'skill_1': ['a', 'b', 'c'], 'skill_2': ['d', 'e', 'f'], 'skill_3': ['g', 'h', 'i']})
>>> env = StaticEnvironment(df, ground_truth_columns={'skill_1': 'ground_truth_1', 'skill_2': 'ground_truth_2'})
"""

df: InternalDataFrame = None
ground_truth_columns: Dict[str, str] = Field(default_factory=dict)
matching_function: Union[str, Callable] = "fuzzy"
matching_threshold: float = 0.9

def get_feedback(
self,
skills: SkillSet,
predictions: InternalDataFrame,
num_feedbacks: Optional[int] = None,
) -> EnvironmentFeedback:
"""
Compare the predictions with the ground truth using the specified matching function.
Args:
skills (SkillSet): The skill set being evaluated.
predictions (InternalDataFrame): The predictions to compare with the ground truth.
num_feedbacks (Optional[int], optional): The number of feedbacks to request. Defaults to all predictions
Returns:
EnvironmentFeedback: The resulting ground truth signal, with matches and errors detailed.
Raises:
NotImplementedError: If the matching_function is unknown.
"""

pred_columns = list(skills.get_skill_outputs())
pred_match = {}
pred_feedback = {}

if num_feedbacks is not None:
predictions = predictions.sample(n=num_feedbacks)

for pred_column in pred_columns:
pred = predictions[pred_column]
gt_column = self.ground_truth_columns.get(pred_column, pred_column)
if gt_column not in self.df.columns:
# if ground truth column is not in the dataframe, assume no ground truth signal - return NaNs
pred_match[pred_column] = InternalSeries(np.nan, index=pred.index)
pred_feedback[pred_column] = InternalSeries(np.nan, index=pred.index)
continue

gt = self.df[gt_column]

gt, pred = gt.align(pred)
nonnull_index = gt.notnull() & pred.notnull()
gt = gt[nonnull_index]
pred = pred[nonnull_index]
# compare ground truth with predictions
if isinstance(self.matching_function, str):
if self.matching_function == "exact":
gt_pred_match = gt == pred
elif self.matching_function == "fuzzy":
gt_pred_match = fuzzy_match(
gt, pred, threshold=self.matching_threshold
)
else:
raise NotImplementedError(
f"Unknown matching function {self.matching_function}"
)
elif callable(self.matching_function):
gt_pred_match = gt.combine(
pred, lambda g, p: self.matching_function(g, p)
)
pred_match[pred_column] = gt_pred_match
# leave feedback about mismatches
match_concat = InternalDataFrameConcat(
[gt_pred_match.rename("match"), gt], axis=1
)
pred_feedback[pred_column] = match_concat.apply(
lambda row: "Prediction is correct."
if row["match"]
else f'Prediction is incorrect. Correct answer: "{row[gt_column]}"'
if not pd.isna(row["match"])
else np.nan,
axis=1,
)

fb = EnvironmentFeedback(
match=InternalDataFrame(pred_match).reindex(predictions.index),
feedback=InternalDataFrame(pred_feedback).reindex(predictions.index),
)
return fb

def get_data_batch(self, batch_size: int = None) -> InternalDataFrame:
"""
Return the dataset containing the ground truth data.
Returns:
InternalDataFrame: The data batch.
"""
if batch_size is not None:
return self.df.sample(n=batch_size)
return self.df

def save(self):
"""
Save the current state of the StaticEnvironment.
"""
raise NotImplementedError("StaticEnvironment does not support save/restore.")

def restore(self):
"""
Restore the state of the StaticEnvironment.
"""
raise NotImplementedError("StaticEnvironment does not support save/restore.")


_environment_registry = {
"static": StaticEnvironment
}


def register_environment(name: str, environment: Type[Environment]):
"""
Register a new environment type.
Args:
name (str): The name of the environment type.
environment (Type[Environment]): The environment class to register.
"""
if name in _environment_registry:
raise ValueError(f"Environment {name} already registered.")

_environment_registry[name] = environment


def create_environment(
name: str, *args, **kwargs
) -> Environment:
"""
Create an environment of the specified type.
Args:
name (str): The name of the environment type.
*args: The arguments to pass to the environment constructor.
**kwargs: The keyword arguments to pass to the environment constructor.
Returns:
Environment: The created environment.
"""
if name not in _environment_registry:
raise ValueError(f"Unknown environment type {name}. Available types: {list(_environment_registry.keys())}")

return _environment_registry[name](*args, **kwargs)
3 changes: 2 additions & 1 deletion adala/environments/code_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import io
from contextlib import redirect_stdout, redirect_stderr
from typing import Dict, Optional, List
from adala.environments.base import StaticEnvironment, EnvironmentFeedback
from adala.environments.base import EnvironmentFeedback
from adala.environments.static_env import StaticEnvironment
from adala.skills import SkillSet
from adala.utils.internal_data import InternalDataFrame

Expand Down
6 changes: 1 addition & 5 deletions adala/environments/kafka.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
from io import StringIO
from aiokafka import AIOKafkaConsumer, AIOKafkaProducer
from adala.utils.internal_data import InternalDataFrame
from adala.environments import AsyncEnvironment, EnvironmentFeedback
from adala.environments.base import register_environment
from adala.environments import Environment, AsyncEnvironment, EnvironmentFeedback
from adala.skills import SkillSet
from adala.utils.logs import print_text

Expand Down Expand Up @@ -198,6 +197,3 @@ async def restore(self):
async def save(self):
raise NotImplementedError("Save is not supported in Kafka environment")


register_environment("kafka", AsyncKafkaEnvironment)
register_environment("kafka_filestream", FileStreamAsyncKafkaEnvironment)
Loading

0 comments on commit e17a2a1

Please sign in to comment.