run black formatter everywhere #249

Merged · 1 commit · Nov 7, 2024
2 changes: 1 addition & 1 deletion adala/agents/base.py
@@ -435,7 +435,7 @@ async def arefine_skill(
predictions = await self.skills.aapply(inputs, runtime=runtime)
else:
predictions = inputs

response = await skill.aimprove(
predictions=predictions,
teacher_runtime=teacher_runtime,
6 changes: 3 additions & 3 deletions adala/environments/kafka.py
@@ -55,16 +55,16 @@ async def initialize(self):
self.kafka_input_topic,
bootstrap_servers=self.kafka_bootstrap_servers,
value_deserializer=lambda v: json.loads(v.decode("utf-8")),
enable_auto_commit=False, # True by default which causes messages to be missed when using getmany()
enable_auto_commit=False, # True by default which causes messages to be missed when using getmany()
auto_offset_reset="earliest",
group_id=self.kafka_input_topic, # ensuring unique group_id to not mix up offsets between topics
group_id=self.kafka_input_topic, # ensuring unique group_id to not mix up offsets between topics
)
await self.consumer.start()

self.producer = AIOKafkaProducer(
bootstrap_servers=self.kafka_bootstrap_servers,
value_serializer=lambda v: json.dumps(v).encode("utf-8"),
acks='all' # waits for all replicas to respond that they have written the message
acks="all", # waits for all replicas to respond that they have written the message
)
await self.producer.start()

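With auto-commit disabled as in this hunk, the consumer is expected to commit offsets itself after each `getmany()` batch. A minimal sketch of that pattern, assuming placeholder topic and broker names rather than anything from this PR:

```python
import asyncio
import json

from aiokafka import AIOKafkaConsumer


async def consume_batches():
    # Hypothetical topic/broker names, for illustration only.
    consumer = AIOKafkaConsumer(
        "input-topic",
        bootstrap_servers="localhost:9092",
        value_deserializer=lambda v: json.loads(v.decode("utf-8")),
        enable_auto_commit=False,  # commit manually so getmany() batches are not skipped
        auto_offset_reset="earliest",
        group_id="input-topic",
    )
    await consumer.start()
    try:
        while True:
            batches = await consumer.getmany(timeout_ms=1000)
            for tp, messages in batches.items():
                for msg in messages:
                    print(msg.value)  # stand-in for real message handling
            if batches:
                await consumer.commit()  # commit only after the batch is processed
    finally:
        await consumer.stop()


if __name__ == "__main__":
    asyncio.run(consume_batches())
```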
8 changes: 5 additions & 3 deletions adala/skills/_base.py
@@ -631,7 +631,7 @@ def _iter_over_chunks(
input = InternalDataFrame([input])

extra_fields = self._get_extra_fields()

# if chunk_size is specified, split the input into chunks and process each chunk separately
if self.chunk_size is not None:
chunks = (
@@ -640,10 +640,12 @@
)
else:
chunks = [input]

# define the row preprocessing function
def row_preprocessing(row):
return partial_str_format(self.input_template, **row, **extra_fields, i=int(row.name) + 1)
return partial_str_format(
self.input_template, **row, **extra_fields, i=int(row.name) + 1
)

total = input.shape[0] // self.chunk_size if self.chunk_size is not None else 1
for chunk in tqdm(chunks, desc="Processing chunks", total=total):
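The body of the `chunks = (...)` generator is collapsed in this view, but the surrounding code splits the input frame into `chunk_size`-row pieces before templating each row. A generic sketch of that kind of chunking, with made-up data and not necessarily the repo's exact expression:

```python
import pandas as pd

df = pd.DataFrame({"text": ["a", "b", "c", "d", "e"]})
chunk_size = 2

# Slice the frame into consecutive chunk_size-row pieces.
chunks = (df.iloc[i : i + chunk_size] for i in range(0, len(df), chunk_size))
for chunk in chunks:
    print(chunk["text"].tolist())
# ['a', 'b'] / ['c', 'd'] / ['e']
```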
32 changes: 20 additions & 12 deletions adala/skills/collection/entity_extraction.py
@@ -13,8 +13,10 @@
logger = logging.getLogger(__name__)


def validate_output_format_for_ner_tag(df: InternalDataFrame, input_field_name: str, output_field_name: str):
'''
def validate_output_format_for_ner_tag(
df: InternalDataFrame, input_field_name: str, output_field_name: str
):
"""
The output format for Labels is:
{
"start": start_idx,
@@ -23,30 +25,30 @@ def validate_output_format_for_ner_tag(df: InternalDataFrame, input_field_name:
"labels": [label1, label2, ...]
}
Sometimes the model cannot populate "text" correctly, but this can be fixed deterministically.
'''
"""
for i, row in df.iterrows():
if row.get("_adala_error"):
logger.warning(f"Error in row {i}: {row['_adala_message']}")
continue
text = row[input_field_name]
entities = row[output_field_name]
for entity in entities:
corrected_text = text[entity["start"]:entity["end"]]
corrected_text = text[entity["start"] : entity["end"]]
if entity.get("text") is None:
entity["text"] = corrected_text
elif entity["text"] != corrected_text:
# this seems to happen rarely if at all in testing, but could lead to invalid predictions
logger.warning(f"text and indices disagree for a predicted entity")
return df


def extract_indices(
df,
input_field_name,
output_field_name,
quote_string_field_name='quote_string',
labels_field_name='label'
):
df,
input_field_name,
output_field_name,
quote_string_field_name="quote_string",
labels_field_name="label",
):
"""
Give the input dataframe with "text" column and "entities" column of the format
```
@@ -354,7 +356,13 @@ def extract_indices(self, df):
"""
input_field_name = self._get_input_field_name()
output_field_name = self._get_output_field_name()
df = extract_indices(df, input_field_name, output_field_name, self._quote_string_field_name, self._labels_field_name)
df = extract_indices(
df,
input_field_name,
output_field_name,
self._quote_string_field_name,
self._labels_field_name,
)
return df

def apply(
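As the reformatted docstring notes, a missing or inconsistent entity `"text"` can be rebuilt deterministically by slicing the source string with the predicted offsets. A standalone sketch of that fix, with made-up data:

```python
def fill_entity_text(text: str, entities: list) -> list:
    """Rebuild each entity's "text" field from its start/end offsets."""
    for entity in entities:
        corrected = text[entity["start"] : entity["end"]]
        if entity.get("text") is None:
            entity["text"] = corrected
        elif entity["text"] != corrected:
            # Offsets win; flag the disagreement instead of trusting the model's string.
            print(f"text/offset mismatch: {entity['text']!r} vs {corrected!r}")
    return entities


entities = fill_entity_text(
    "Alice went to Paris",
    [{"start": 14, "end": 19, "labels": ["LOC"], "text": None}],
)
print(entities[0]["text"])  # Paris
```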
36 changes: 22 additions & 14 deletions adala/skills/collection/label_studio.py
@@ -10,7 +10,9 @@

from label_studio_sdk.label_interface import LabelInterface
from label_studio_sdk.label_interface.control_tags import ControlTag
from label_studio_sdk._extensions.label_studio_tools.core.utils.json_schema import json_schema_to_pydantic
from label_studio_sdk._extensions.label_studio_tools.core.utils.json_schema import (
json_schema_to_pydantic,
)

from .entity_extraction import extract_indices, validate_output_format_for_ner_tag

@@ -23,7 +25,9 @@ class LabelStudioSkill(TransformSkill):
input_template: str = "Annotate the input data according to the provided schema."
# TODO: remove output_template, fix calling @model_validator(mode='after') in the base class
output_template: str = "Output: {field_name}"
response_model: Type[BaseModel] = BaseModel # why validate_response_model is called in the base class?
response_model: Type[BaseModel] = (
BaseModel # why validate_response_model is called in the base class?
)
# ------------------------------
label_config: str = "<View></View>"

@@ -33,21 +37,21 @@ def ner_tags(self) -> Iterator[ControlTag]:
# check if the input config has NER tag (<Labels> + <Text>), and return its `from_name` and `to_name`
interface = LabelInterface(self.label_config)
for tag in interface.controls:
#TODO: don't need to check object tag because at this point, unusable control tags should have been stripped out of the label config, but confirm this - maybe move this logic to LSE
if tag.tag == 'Labels':
# TODO: don't need to check object tag because at this point, unusable control tags should have been stripped out of the label config, but confirm this - maybe move this logic to LSE
if tag.tag == "Labels":
yield tag
@model_validator(mode='after')

@model_validator(mode="after")
def validate_response_model(self):

interface = LabelInterface(self.label_config)
logger.debug(f'Read labeling config {self.label_config}')
logger.debug(f"Read labeling config {self.label_config}")

self.field_schema = interface.to_json_schema()
logger.debug(f'Converted labeling config to json schema: {self.field_schema}')
logger.debug(f"Converted labeling config to json schema: {self.field_schema}")

return self

def _create_response_model_from_field_schema(self):
pass

@@ -56,7 +60,7 @@ def apply(
input: InternalDataFrame,
runtime: Runtime,
) -> InternalDataFrame:

with json_schema_to_pydantic(self.field_schema) as ResponseModel:
return runtime.batch_to_batch(
input,
@@ -81,10 +85,14 @@ async def aapply(
response_model=ResponseModel,
)
for ner_tag in self.ner_tags():
input_field_name = ner_tag.objects[0].value.lstrip('$')
input_field_name = ner_tag.objects[0].value.lstrip("$")
output_field_name = ner_tag.name
quote_string_field_name = 'text'
quote_string_field_name = "text"
df = pd.concat([input, output], axis=1)
output = validate_output_format_for_ner_tag(df, input_field_name, output_field_name)
output = extract_indices(output, input_field_name, output_field_name, quote_string_field_name)
output = validate_output_format_for_ner_tag(
df, input_field_name, output_field_name
)
output = extract_indices(
output, input_field_name, output_field_name, quote_string_field_name
)
return output
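For context on the `ner_tags()` loop reformatted above: it scans the labeling config for `<Labels>` control tags and pairs each with the text variable it annotates. A rough sketch of that check against a made-up config, using only the `LabelInterface` attributes that appear in this diff:

```python
from label_studio_sdk.label_interface import LabelInterface

# Made-up labeling config with a <Labels> + <Text> pair, purely for illustration.
label_config = """
<View>
  <Labels name="entities" toName="doc">
    <Label value="PER"/>
    <Label value="LOC"/>
  </Labels>
  <Text name="doc" value="$text"/>
</View>
"""

interface = LabelInterface(label_config)
for tag in interface.controls:
    if tag.tag == "Labels":
        # Control name and the input variable it points at (mirrors aapply above).
        print(tag.name, tag.objects[0].value.lstrip("$"))
# expected output: entities text
```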
43 changes: 29 additions & 14 deletions adala/skills/collection/prompt_improvement.py
@@ -1,6 +1,13 @@
import json
import logging
from pydantic import BaseModel, field_validator, Field, ConfigDict, model_validator, AfterValidator
from pydantic import (
BaseModel,
field_validator,
Field,
ConfigDict,
model_validator,
AfterValidator,
)
from adala.skills import Skill
from typing import Any, Dict, List, Optional, Union
from typing_extensions import Annotated
@@ -15,7 +22,9 @@
def validate_used_variables(value: str) -> str:
templates = parse_template(value, include_texts=False)
if not templates:
raise ValueError("At least one input variable must be used in the prompt, formatted with curly braces like this: {input_variable}")
raise ValueError(
"At least one input variable must be used in the prompt, formatted with curly braces like this: {input_variable}"
)
return value


@@ -52,34 +61,38 @@ class PromptImprovementSkill(AnalysisSkill):

name: str = "prompt_improvement"
instructions: str = "Improve current prompt"
input_template: str = "" # Used to provide a few shot examples of input-output pairs
input_template: str = (
"" # Used to provide a few shot examples of input-output pairs
)
input_prefix: str = "" # Used to provide additional context for the input
input_separator: str = "\n"

response_model = PromptImprovementSkillResponseModel

@model_validator(mode="after")
def validate_prompts(self):

def get_json_template(fields):
json_body = ", ".join([f'"{field}": "{{{field}}}"' for field in fields])
return "{" + json_body + "}"

if isinstance(self.skill_to_improve, LabelStudioSkill):
model_json_schema = self.skill_to_improve.field_schema
else:
model_json_schema = self.skill_to_improve.response_model.model_json_schema()

# TODO: can remove this when only LabelStudioSkill is supported
label_config = getattr(self.skill_to_improve, 'label_config', '<View>Not available</View>')
label_config = getattr(
self.skill_to_improve, "label_config", "<View>Not available</View>"
)

input_variables = self.input_variables
output_variables = list(model_json_schema['properties'].keys())
output_variables = list(model_json_schema["properties"].keys())
input_json_template = get_json_template(input_variables)
output_json_template = get_json_template(output_variables)
self.input_template = f'{input_json_template} --> {output_json_template}'
self.input_prefix = f'''
self.input_template = f"{input_json_template} --> {output_json_template}"

self.input_prefix = f"""
## Current prompt:
```
{self.skill_to_improve.input_template}
@@ -102,10 +115,12 @@ def get_json_template(fields):

## Input-Output Examples:

'''
"""

# TODO: deprecated, leave self.output_template for compatibility
self.output_template = output_json_template

logger.debug(f'Instructions: {self.instructions}\nInput template: {self.input_template}\nInput prefix: {self.input_prefix}')

logger.debug(
f"Instructions: {self.instructions}\nInput template: {self.input_template}\nInput prefix: {self.input_prefix}"
)
return self
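The `get_json_template` helper touched in this hunk simply joins field names into a JSON-shaped format string, which is then used to build the `{input} --> {output}` template. A quick standalone illustration with invented field names:

```python
def get_json_template(fields):
    json_body = ", ".join([f'"{field}": "{{{field}}}"' for field in fields])
    return "{" + json_body + "}"


input_json_template = get_json_template(["text"])
output_json_template = get_json_template(["label"])
print(f"{input_json_template} --> {output_json_template}")
# {"text": "{text}"} --> {"label": "{label}"}
```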
4 changes: 3 additions & 1 deletion adala/skills/collection/rag.py
@@ -73,7 +73,9 @@ def apply(
input_strings, num_results=self.num_results
)
rag_input_strings = [
"\n\n".join(partial_str_format(self.rag_input_template, **i) for i in rag_items)
"\n\n".join(
partial_str_format(self.rag_input_template, **i) for i in rag_items
)
for rag_items in rag_input_data
]
output_fields = self.get_output_fields()
46 changes: 28 additions & 18 deletions adala/utils/parse.py
@@ -16,7 +16,7 @@ def get_value(self, key, args, kwds):
return "{" + key + "}"
else:
Formatter.get_value(key, args, kwds)

def format_field(self, value, format_spec):
try:
return super().format_field(value, format_spec)
@@ -25,14 +25,16 @@ def format_field(self, value, format_spec):
if value.startswith("{") and value.endswith("}"):
return value[:-1] + ":" + format_spec + "}"

def _vformat(self, format_string, args, kwargs, used_args, recursion_depth,
auto_arg_index=0):
def _vformat(
self, format_string, args, kwargs, used_args, recursion_depth, auto_arg_index=0
):
# copied verbatim from parent class except for the # HACK
if recursion_depth < 0:
raise ValueError('Max string recursion exceeded')
raise ValueError("Max string recursion exceeded")
result = []
for literal_text, field_name, format_spec, conversion in \
self.parse(format_string):
for literal_text, field_name, format_spec, conversion in self.parse(
format_string
):

# output the literal text
if literal_text:
@@ -44,18 +46,22 @@ def _vformat(self, format_string, args, kwargs, used_args, recursion_depth,
# the formatting

# handle arg indexing when empty field_names are given.
if field_name == '':
if field_name == "":
if auto_arg_index is False:
raise ValueError('cannot switch from manual field '
'specification to automatic field '
'numbering')
raise ValueError(
"cannot switch from manual field "
"specification to automatic field "
"numbering"
)
field_name = str(auto_arg_index)
auto_arg_index += 1
elif field_name.isdigit():
if auto_arg_index:
raise ValueError('cannot switch from manual field '
'specification to automatic field '
'numbering')
raise ValueError(
"cannot switch from manual field "
"specification to automatic field "
"numbering"
)
# disable auto arg incrementing, if it gets
# used later on, then an exception will be raised
auto_arg_index = False
@@ -70,19 +76,23 @@ def _vformat(self, format_string, args, kwargs, used_args, recursion_depth,

# expand the format spec, if needed
format_spec, auto_arg_index = self._vformat(
format_spec, args, kwargs,
used_args, recursion_depth-1,
auto_arg_index=auto_arg_index)
format_spec,
args,
kwargs,
used_args,
recursion_depth - 1,
auto_arg_index=auto_arg_index,
)

# format the object and append to the result
# HACK: if the format_spec is invalid, assume this field_name was not meant to be a variable, and don't substitute anything
formatted_field = self.format_field(obj, format_spec)
if formatted_field is None:
result.append('{' + ':'.join([field_name, format_spec]) + '}')
result.append("{" + ":".join([field_name, format_spec]) + "}")
else:
result.append(formatted_field)

return ''.join(result), auto_arg_index
return "".join(result), auto_arg_index


PartialStringFormat = PartialStringFormatter()
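The class reformatted above is a partial string formatter: placeholders without a matching keyword are left in the output instead of raising `KeyError`. A simplified stand-in showing the observable behavior (not the repo's exact implementation, which also guards against invalid format specs):

```python
from string import Formatter


class PartialFormatter(Formatter):
    """Simplified stand-in: unknown keys stay as literal {placeholders}."""

    def get_value(self, key, args, kwargs):
        if isinstance(key, str) and key not in kwargs:
            return "{" + key + "}"
        return super().get_value(key, args, kwargs)


fmt = PartialFormatter()
print(fmt.format("Input: {text} -> Output: {label}", text="hello"))
# Input: hello -> Output: {label}
```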