Skip to content

Commit

Permalink
Fix changes
Browse files Browse the repository at this point in the history
  • Loading branch information
nik committed Nov 10, 2023
1 parent f5ba24c commit 89e3381
Showing 1 changed file with 19 additions and 28 deletions.
47 changes: 19 additions & 28 deletions adala/skills/skillset.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def apply(
self,
dataset: Union[Dataset, InternalDataFrame],
runtime: Runtime,
improved_skill: Optional[str] = None,
improved_skill: Optional[str] = None
) -> InternalDataFrame:
"""
Apply the skill set to a dataset using a specified runtime.
Expand All @@ -43,9 +43,7 @@ def apply(
"""

@abstractmethod
def select_skill_to_improve(
self, accuracy: Mapping, accuracy_threshold: Optional[float] = 1.0
) -> Optional[BaseSkill]:
def select_skill_to_improve(self, accuracy: Mapping, accuracy_threshold: Optional[float] = 1.0) -> Optional[BaseSkill]:
"""
Select skill to improve based on accuracy.
Expand Down Expand Up @@ -97,7 +95,7 @@ class LinearSkillSet(SkillSet):
Attributes:
skills (Union[List[str], Dict[str, str], List[BaseSkill], Dict[str, BaseSkill]]): Provided skills
skill_sequence (List[str], optional): Ordered list of skill names indicating the order
skill_sequence (List[str], optional): Ordered list of skill names indicating the order
in which they should be acquired.
By default, lexographical order of skill names is used.
input_data_field (Optional[str], optional): Name of the input data field. Defaults to None.
Expand All @@ -119,11 +117,8 @@ class LinearSkillSet(SkillSet):
skill_sequence: List[str] = None
input_data_field: Optional[str] = None

@field_validator("skills", mode="before")
@classmethod
def skills_validator(
cls, v: Union[List[str], List[BaseSkill], Dict[str, BaseSkill]]
) -> Dict[str, BaseSkill]:
@field_validator('skills', mode='before')
def skills_validator(cls, v: Union[List[str], List[BaseSkill], Dict[str, BaseSkill]]) -> Dict[str, BaseSkill]:
"""
Validates and converts the skills attribute to a dictionary of skill names to BaseSkill instances.
Expand All @@ -145,7 +140,7 @@ def skills_validator(
skills[skill_name] = LLMSkill(
name=skill_name,
instructions=instructions,
input_data_field=input_data_field,
input_data_field=input_data_field
)
# Linear skillset creates skills pipeline - update input_data_field for next skill
input_data_field = skill_name
Expand All @@ -155,7 +150,7 @@ def skills_validator(
skills[skill_name] = LLMSkill(
name=skill_name,
instructions=instructions,
input_data_field=input_data_field,
input_data_field=input_data_field
)
# Linear skillset creates skills pipeline - update input_data_field for next skill
input_data_field = skill_name
Expand All @@ -169,8 +164,8 @@ def skills_validator(
raise ValueError(f"skills must be a list or dictionary, not {type(skills)}")
return skills

@model_validator(mode="after")
def skill_sequence_validator(self) -> "LinearSkillSet":
@model_validator(mode='after')
def skill_sequence_validator(self) -> 'LinearSkillSet':
"""
Validates and sets the default order for the skill sequence if not provided.
Expand All @@ -181,11 +176,9 @@ def skill_sequence_validator(self) -> "LinearSkillSet":
# use default skill sequence defined by lexicographical order
self.skill_sequence = list(self.skills.keys())
if len(self.skill_sequence) != len(self.skills):
raise ValueError(
f"skill_sequence must contain all skill names - "
f"length of skill_sequence is {len(self.skill_sequence)} "
f"while length of skills is {len(self.skills)}"
)
raise ValueError(f"skill_sequence must contain all skill names - "
f"length of skill_sequence is {len(self.skill_sequence)} "
f"while length of skills is {len(self.skills)}")
return self

def apply(
Expand All @@ -208,9 +201,7 @@ def apply(
predictions = None
if improved_skill:
# start from the specified skill, assuming previous skills have already been applied
skill_sequence = self.skill_sequence[
self.skill_sequence.index(improved_skill) :
]
skill_sequence = self.skill_sequence[self.skill_sequence.index(improved_skill):]
else:
skill_sequence = self.skill_sequence
for i, skill_name in enumerate(skill_sequence):
Expand All @@ -223,10 +214,12 @@ def apply(
return predictions

def select_skill_to_improve(
self, accuracy: Mapping, accuracy_threshold: Optional[float] = 0.9
self,
accuracy: Mapping,
accuracy_threshold: Optional[float] = 1.0
) -> Optional[BaseSkill]:
"""
Selects the first skill in the sequence with accuracy below the threshold to improve.
Selects the skill with the lowest accuracy to improve.
Args:
accuracy (Mapping): Accuracy of each skill.
Expand All @@ -243,10 +236,8 @@ def __rich__(self):
# TODO: move it to a base class and use repr derived from Skills
text = f"[bold blue]Total Agent Skills: {len(self.skills)}[/bold blue]\n\n"
for skill in self.skills.values():
text += (
f"[bold underline green]{skill.name}[/bold underline green]\n"
f"[green]{skill.instructions}[green]\n"
)
text += f'[bold underline green]{skill.name}[/bold underline green]\n' \
f'[green]{skill.instructions}[green]\n'
return text


Expand Down

0 comments on commit 89e3381

Please sign in to comment.