Skip to content

Commit

Permalink
Adding documentation to the environment, small readme changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Malyuk authored and Michael Malyuk committed Nov 1, 2023
1 parent 4991738 commit ef5d687
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 28 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -290,13 +290,13 @@ adala help <command>

## 🗺 Roadmap

- [ ] Low-level skill management (i.e. agent.get_skill("name"))
- [x] Low-level skill management (i.e. agent.get_skill("name"))
- [ ] Extend environment with one more example
- [ ] Multi-task learning (learn multiple skills at once)
- [ ] Calculate and store top line Agent metrics (predictions created, runtime executions, learning loops, etc)
- [ ] Create Named Entity Recognition Skill
- [ ] Extend environment with one more example
- [ ] Command line utility (see the source for this readme for example)
- [ ] REST API to interact with Adala
- [ ] Multi-task learning (learn multiple skills at once)
- [ ] Vision and multi-modal agent skills

## 🤩 Contributing to Adala
Expand Down
150 changes: 125 additions & 25 deletions adala/environments/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,44 @@


class GroundTruthSignal(BaseModel):
"""
A model that represents the comparison between predictions and ground truth data,
potentially holding information about matching results and errors per skill.
Attributes:
match (InternalDataFrame): A DataFrame indicating the correctness of predictions.
errors (Optional[Dict[str, InternalDataFrame]]): A dictionary mapping skill names to DataFrames
containing the errors between predictions and ground truth. Default is None.
"""

match: InternalDataFrame
errors: Optional[Dict[str, InternalDataFrame]] = None

def get_accuracy(self) -> InternalSeries:
"""
Calculate the accuracy of predictions as the mean of matches.
Returns:
InternalSeries: A series representing the accuracy of predictions.
"""

return self.match.mean()

def get_errors(self, skill_name: str) -> InternalDataFrame:
"""
Retrieve the errors associated with a particular skill.
Args:
skill_name (str): The name of the skill to retrieve errors for.
Returns:
InternalDataFrame: A DataFrame with two columns ["predictions", "ground_truth name"]
representing the errors.
Raises:
AssertionError: If the error DataFrame does not have exactly two columns.
"""

errors = self.errors[skill_name]
assert len(errors.columns) == 2 # ["predictions", "ground_truth name"]
return errors
Expand All @@ -34,52 +65,75 @@ class Config:


class Environment(BaseModel, ABC):
"""Abstract base class for environments.
The environment provides a mechanism to obtain ground truth information from raw data and predictions,
and also facilitates comparison of ground truth with predictions.
"""
An abstract base class that defines the structure and required methods for an environment
in which machine learning models operate and are evaluated against ground truth data.
Attributes:
Config (class): Configuration for the environment class, allows arbitrary types.
Subclasses should implement methods to handle feedback requests, comparison to ground truth,
dataset conversion, and state persistence.
"""

@abstractmethod
def request_feedback(self, skill_set: SkillSet, predictions: InternalDataFrame):
"""Request user feedback using predictions and update internal ground truth set."""
"""
Abstract method to request user feedback on the predictions made by the model.
Args:
skill_set (SkillSet): The set of skills/models whose predictions are being evaluated.
predictions (InternalDataFrame): The predictions made by the skills/models.
"""

@abstractmethod
def compare_to_ground_truth(self, skill_set: SkillSet, predictions: InternalDataFrame) -> GroundTruthSignal:
"""Compare predictions with ground truth and return the results."""
"""
Abstract method to compare predictions with ground truth data.
Args:
skill_set (SkillSet): The set of skills/models whose predictions are being evaluated.
predictions (InternalDataFrame): The predictions made by the skills/models.
Returns:
GroundTruthSignal: An instance of GroundTruthSignal containing the comparison results.
"""

@abstractmethod
def as_dataset(self) -> Dataset:
"""Convert the environment to a dataset."""
"""
Abstract method to convert the environment's state into a dataset.
Returns:
Dataset: A dataset representing the environment's state.
"""

@abstractmethod
def save(self):
"""Persist the state of the environment."""
"""
Abstract method to persist the current state of the environment.
"""

@abstractmethod
def restore(self):
"""Retrieve and set the state of the environment."""
"""
Abstract method to restore the environment's state from persisted data.
"""

class Config:
arbitrary_types_allowed = True


class BasicEnvironment(Environment):
"""Basic environment implementation.
This environment assumes the ground truth is provided explicitly with the input data.
For comparison with ground truth, exact matching is used.
"""
A concrete implementation of the Environment abstract base class,
assuming the ground truth is provided and comparison is based on exact or fuzzy matching.
Attributes:
ground_truth_dataset (DataFrameDataset): Dataset containing the ground truth data.
Defaults to an empty DataFrameDataset.
ground_truth_column (str): Name of the column containing ground truth in the dataset.
Defaults to 'ground_truth'.
"""
ground_truth_dataset (Union[InternalDataFrame, DataFrameDataset]): Dataset containing
the ground truth data, defaulting to an empty DataFrameDataset.
ground_truth_columns (Dict[str, str]): A dictionary mapping skill names to their corresponding
ground truth columns in the dataset.
matching_function (str): The name of the matching function to use, defaults to 'exact'.
matching_threshold (float): The threshold for fuzzy matching, defaults to 0.8.
"""

ground_truth_dataset: Union[InternalDataFrame, DataFrameDataset] = Field(default_factory=DataFrameDataset)
ground_truth_columns: Dict[str, str]
Expand All @@ -88,12 +142,31 @@ class BasicEnvironment(Environment):

@field_validator('ground_truth_dataset')
def _validate_ground_truth_dataset(cls, v):
"""
Validate the ground_truth_dataset field to ensure it is converted to DataFrameDataset if needed.
Args:
v: The value to validate.
Returns:
The validated value, possibly converted to DataFrameDataset.
Raises:
ValidationError: If the validation fails.
"""

if isinstance(v, InternalDataFrame):
return DataFrameDataset(df=v)
return v

def request_feedback(self, skill: BaseSkill, predictions: InternalDataFrame):
"""In the BasicEnvironment, ground truth is already provided with the input data."""
"""
In the BasicEnvironment, this method is a placeholder as ground truth is already provided with the input data.
Args:
skill (BaseSkill): The skill being evaluated.
predictions (InternalDataFrame): The predictions to be reviewed.
"""

def compare_to_ground_truth(self, skill_set: SkillSet, predictions: InternalDataFrame) -> GroundTruthSignal:
"""Compare the predictions with the ground truth using exact matching.
Expand All @@ -103,6 +176,18 @@ def compare_to_ground_truth(self, skill_set: SkillSet, predictions: InternalData
predictions (InternalDataFrame): The predictions to compare with ground truth.
Returns:
GroundTruthSignal: The ground truth signal.
""""""
Compare the predictions with the ground truth using the specified matching function.
Args:
skill_set (SkillSet): The skill set being evaluated.
predictions (InternalDataFrame): The predictions to compare with the ground truth.
Returns:
GroundTruthSignal: The resulting ground truth signal, with matches and errors detailed.
Raises:
NotImplementedError: If the matching_function is unknown.
"""

ground_truth_match = InternalDataFrame()
Expand Down Expand Up @@ -146,16 +231,31 @@ def as_dataset(self) -> Dataset:
Returns:
Dataset: The dataset containing ground truth data.
""""""
Return the dataset containing the ground truth data.
Returns:
Dataset: The ground truth dataset as a DataFrameDataset.
"""

return self.ground_truth_dataset

def save(self):
"""Save method for BasicEnvironment. Not implemented."""
"""Save method for BasicEnvironment. Not implemented.""""""
Save the current state of the BasicEnvironment.
Raises:
NotImplementedError: This method is not implemented for BasicEnvironment.
"""

raise NotImplementedError

def restore(self):
"""Restore method for BasicEnvironment. Not implemented."""
"""Restore method for BasicEnvironment. Not implemented.""""""
Restore the state of the BasicEnvironment.
Raises:
NotImplementedError: This method is not implemented for BasicEnvironment.
"""

raise NotImplementedError

0 comments on commit ef5d687

Please sign in to comment.