
Commit 7b61431: fix tests
clefourrier committed Jul 5, 2024
Parent: 129ba24
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions tests/test_unit_harness_metrics.py
@@ -73,7 +73,9 @@ def pytest_generate_tests(metafunc: pytest.Metafunc):
 def test_model_prediction(prompt_inputs: tuple[str, str, list]):
     """Evaluates a model on a full task - is parametrized using pytest_generate_test"""
     metric, task_name, examples = prompt_inputs
-    print(metric, task_name)
+    metric_name = metric
+    metric = Metrics[metric].value
+    print(metric_name, task_name)
     for example in examples:
         formatted_doc = {
             k: v
@@ -83,7 +85,7 @@ def test_model_prediction(prompt_inputs: tuple[str, str, list]):
         print(formatted_doc)
         formatted_doc["query"] = formatted_doc.pop("full_prompt")
         formatted_doc = Doc(**formatted_doc)
-        error_msg = f"Metric {metric} failed on input {formatted_doc} from task {task_name}.\n"
+        error_msg = f"Metric {metric_name} failed on input {formatted_doc} from task {task_name}.\n"

         results = [ModelReturn(result=i, input_tokens=[], generated_tokens=[]) for i in example["predictions"]]
         # todo: update to create list of ModelResults in results
@@ -122,23 +124,23 @@ def test_model_prediction(prompt_inputs: tuple[str, str, list]):


 def apply_metric(metric, results, formatted_doc: Doc):
-    if Metrics[metric].value.category == MetricCategory.TARGET_PERPLEXITY:
+    if metric.category == MetricCategory.TARGET_PERPLEXITY:
         _, cur_outputs = apply_target_perplexity_metric(results=results, formatted_doc=formatted_doc, metrics=[metric])
         return cur_outputs
-    if Metrics[metric].value.category == MetricCategory.PERPLEXITY:
+    if metric.category == MetricCategory.PERPLEXITY:
         _, cur_outputs = apply_perplexity_metric(results=results, formatted_doc=formatted_doc, metrics=[metric])
         return cur_outputs
-    if Metrics[metric].value.category in [
+    if metric.category in [
         MetricCategory.GENERATIVE,
         MetricCategory.GENERATIVE_LOGPROB,
         MetricCategory.GENERATIVE_SAMPLING,
     ]:
         _, cur_outputs = apply_generative_metric(results=results, formatted_doc=formatted_doc, metrics=[metric])
         return cur_outputs
-    if Metrics[metric].value.category == MetricCategory.MULTICHOICE:
+    if metric.category == MetricCategory.MULTICHOICE:
         _, cur_outputs = apply_multichoice_metric(results=results, formatted_doc=formatted_doc, metrics=[metric])
         return cur_outputs
-    if Metrics[metric].value.category == MetricCategory.MULTICHOICE_ONE_TOKEN:
+    if metric.category == MetricCategory.MULTICHOICE_ONE_TOKEN:
         _, cur_outputs = apply_multichoice_metric_one_token(
             results=results, formatted_doc=formatted_doc, metrics=[metric]
         )
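In short, the change resolves the parametrized metric name to its `Metrics` enum value once inside the test, keeps the original name around for the error message, and lets `apply_metric` read `metric.category` directly instead of repeating the `Metrics[...]` lookup on every branch. The following is a minimal, self-contained sketch of that pattern; the classes and names below are hypothetical stand-ins, not lighteval's real `Metrics` / `MetricCategory` definitions or import paths.

# Hypothetical stand-ins to illustrate the pattern this commit adopts;
# they are NOT lighteval's actual Metrics / MetricCategory classes.
from dataclasses import dataclass
from enum import Enum, auto


class MetricCategory(Enum):  # stand-in for lighteval's MetricCategory
    TARGET_PERPLEXITY = auto()
    MULTICHOICE = auto()


@dataclass
class MetricSpec:  # stand-in for what `Metrics[<name>].value` would return
    name: str
    category: MetricCategory


# stand-in registry, analogous to indexing the Metrics enum by name
METRICS = {"loglikelihood_acc": MetricSpec("loglikelihood_acc", MetricCategory.MULTICHOICE)}


def apply_metric(metric: MetricSpec) -> str:
    # After the fix, dispatch reads `metric.category` directly
    # instead of re-doing a `Metrics[name].value.category` lookup.
    if metric.category == MetricCategory.MULTICHOICE:
        return "multichoice path"
    return "other path"


metric_name = "loglikelihood_acc"  # what the test is parametrized with
metric = METRICS[metric_name]      # analogous to `Metrics[metric].value`
assert apply_metric(metric) == "multichoice path"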
