From ef5d687256be114f54f07e4253f8be6f9b35b21d Mon Sep 17 00:00:00 2001 From: Michael Malyuk Date: Tue, 31 Oct 2023 18:53:06 -0700 Subject: [PATCH] Adding documentation to the environment, small readme changes --- README.md | 6 +- adala/environments/base.py | 150 ++++++++++++++++++++++++++++++------- 2 files changed, 128 insertions(+), 28 deletions(-) diff --git a/README.md b/README.md index b453a17..de4bca1 100644 --- a/README.md +++ b/README.md @@ -290,13 +290,13 @@ adala help ## 🗺 Roadmap -- [ ] Low-level skill management (i.e. agent.get_skill("name")) +- [x] Low-level skill management (i.e. agent.get_skill("name")) +- [ ] Extend environment with one more example +- [ ] Multi-task learning (learn multiple skills at once) - [ ] Calculate and store top line Agent metrics (predictions created, runtime executions, learning loops, etc) - [ ] Create Named Entity Recognition Skill -- [ ] Extend environment with one more example - [ ] Command line utility (see the source for this readme for example) - [ ] REST API to interact with Adala -- [ ] Multi-task learning (learn multiple skills at once) - [ ] Vision and multi-modal agent skills ## 🤩 Contributing to Adala diff --git a/adala/environments/base.py b/adala/environments/base.py index fe3e399..abbeba7 100644 --- a/adala/environments/base.py +++ b/adala/environments/base.py @@ -10,13 +10,44 @@ class GroundTruthSignal(BaseModel): + """ + A model that represents the comparison between predictions and ground truth data, + potentially holding information about matching results and errors per skill. + + Attributes: + match (InternalDataFrame): A DataFrame indicating the correctness of predictions. + errors (Optional[Dict[str, InternalDataFrame]]): A dictionary mapping skill names to DataFrames + containing the errors between predictions and ground truth. Default is None. + """ + match: InternalDataFrame errors: Optional[Dict[str, InternalDataFrame]] = None def get_accuracy(self) -> InternalSeries: + """ + Calculate the accuracy of predictions as the mean of matches. + + Returns: + InternalSeries: A series representing the accuracy of predictions. + """ + return self.match.mean() def get_errors(self, skill_name: str) -> InternalDataFrame: + """ + Retrieve the errors associated with a particular skill. + + Args: + skill_name (str): The name of the skill to retrieve errors for. + + Returns: + InternalDataFrame: A DataFrame with two columns ["predictions", "ground_truth name"] + representing the errors. + + Raises: + AssertionError: If the error DataFrame does not have exactly two columns. + """ + errors = self.errors[skill_name] assert len(errors.columns) == 2 # ["predictions", "ground_truth name"] return errors @@ -34,52 +65,75 @@ class Config: class Environment(BaseModel, ABC): - """Abstract base class for environments. - - The environment provides a mechanism to obtain ground truth information from raw data and predictions, - and also facilitates comparison of ground truth with predictions. + """ + An abstract base class that defines the structure and required methods for an environment + in which machine learning models operate and are evaluated against ground truth data. - Attributes: - Config (class): Configuration for the environment class, allows arbitrary types. + Subclasses should implement methods to handle feedback requests, comparison to ground truth, + dataset conversion, and state persistence. """ - + @abstractmethod def request_feedback(self, skill_set: SkillSet, predictions: InternalDataFrame): - """Request user feedback using predictions and update internal ground truth set.""" + """ + Abstract method to request user feedback on the predictions made by the model. + + Args: + skill_set (SkillSet): The set of skills/models whose predictions are being evaluated. + predictions (InternalDataFrame): The predictions made by the skills/models. + """ @abstractmethod def compare_to_ground_truth(self, skill_set: SkillSet, predictions: InternalDataFrame) -> GroundTruthSignal: - """Compare predictions with ground truth and return the results.""" + """ + Abstract method to compare predictions with ground truth data. + + Args: + skill_set (SkillSet): The set of skills/models whose predictions are being evaluated. + predictions (InternalDataFrame): The predictions made by the skills/models. + + Returns: + GroundTruthSignal: An instance of GroundTruthSignal containing the comparison results. + """ @abstractmethod def as_dataset(self) -> Dataset: - """Convert the environment to a dataset.""" + """ + Abstract method to convert the environment's state into a dataset. + + Returns: + Dataset: A dataset representing the environment's state. + """ @abstractmethod def save(self): - """Persist the state of the environment.""" + """ + Abstract method to persist the current state of the environment. + """ @abstractmethod def restore(self): - """Retrieve and set the state of the environment.""" + """ + Abstract method to restore the environment's state from persisted data. + """ class Config: arbitrary_types_allowed = True class BasicEnvironment(Environment): - """Basic environment implementation. - - This environment assumes the ground truth is provided explicitly with the input data. - For comparison with ground truth, exact matching is used. + """ + A concrete implementation of the Environment abstract base class, + assuming the ground truth is provided and comparison is based on exact or fuzzy matching. Attributes: - ground_truth_dataset (DataFrameDataset): Dataset containing the ground truth data. - Defaults to an empty DataFrameDataset. - ground_truth_column (str): Name of the column containing ground truth in the dataset. - Defaults to 'ground_truth'. - - """ + ground_truth_dataset (Union[InternalDataFrame, DataFrameDataset]): Dataset containing + the ground truth data, defaulting to an empty DataFrameDataset. + ground_truth_columns (Dict[str, str]): A dictionary mapping skill names to their corresponding + ground truth columns in the dataset. + matching_function (str): The name of the matching function to use, defaults to 'exact'. + matching_threshold (float): The threshold for fuzzy matching, defaults to 0.8. + """ ground_truth_dataset: Union[InternalDataFrame, DataFrameDataset] = Field(default_factory=DataFrameDataset) ground_truth_columns: Dict[str, str] @@ -88,12 +142,31 @@ class BasicEnvironment(Environment): @field_validator('ground_truth_dataset') def _validate_ground_truth_dataset(cls, v): + """ + Validate the ground_truth_dataset field to ensure it is converted to DataFrameDataset if needed. + + Args: + v: The value to validate. + + Returns: + The validated value, possibly converted to DataFrameDataset. + + Raises: + ValidationError: If the validation fails. + """ + if isinstance(v, InternalDataFrame): return DataFrameDataset(df=v) return v def request_feedback(self, skill: BaseSkill, predictions: InternalDataFrame): - """In the BasicEnvironment, ground truth is already provided with the input data.""" + """ + In the BasicEnvironment, this method is a placeholder as ground truth is already provided with the input data. + + Args: + skill (BaseSkill): The skill being evaluated. + predictions (InternalDataFrame): The predictions to be reviewed. + """ def compare_to_ground_truth(self, skill_set: SkillSet, predictions: InternalDataFrame) -> GroundTruthSignal: """Compare the predictions with the ground truth using exact matching. @@ -103,6 +176,18 @@ def compare_to_ground_truth(self, skill_set: SkillSet, predictions: InternalData predictions (InternalDataFrame): The predictions to compare with ground truth. Returns: GroundTruthSignal: The ground truth signal. + """""" + Compare the predictions with the ground truth using the specified matching function. + + Args: + skill_set (SkillSet): The skill set being evaluated. + predictions (InternalDataFrame): The predictions to compare with the ground truth. + + Returns: + GroundTruthSignal: The resulting ground truth signal, with matches and errors detailed. + + Raises: + NotImplementedError: If the matching_function is unknown. """ ground_truth_match = InternalDataFrame() @@ -146,16 +231,31 @@ def as_dataset(self) -> Dataset: Returns: Dataset: The dataset containing ground truth data. + """""" + Return the dataset containing the ground truth data. + + Returns: + Dataset: The ground truth dataset as a DataFrameDataset. """ return self.ground_truth_dataset def save(self): - """Save method for BasicEnvironment. Not implemented.""" + """Save method for BasicEnvironment. Not implemented."""""" + Save the current state of the BasicEnvironment. + + Raises: + NotImplementedError: This method is not implemented for BasicEnvironment. + """ raise NotImplementedError def restore(self): - """Restore method for BasicEnvironment. Not implemented.""" + """Restore method for BasicEnvironment. Not implemented."""""" + Restore the state of the BasicEnvironment. + + Raises: + NotImplementedError: This method is not implemented for BasicEnvironment. + """ raise NotImplementedError