From 7b6143192399fb0f4fb383ef72a122972bd417c1 Mon Sep 17 00:00:00 2001
From: "clementine@huggingface.co"
Date: Fri, 5 Jul 2024 13:07:31 +0000
Subject: [PATCH] fix tests

---
 tests/test_unit_harness_metrics.py | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/tests/test_unit_harness_metrics.py b/tests/test_unit_harness_metrics.py
index d8a6503ac..bbd99bf7e 100644
--- a/tests/test_unit_harness_metrics.py
+++ b/tests/test_unit_harness_metrics.py
@@ -73,7 +73,9 @@ def pytest_generate_tests(metafunc: pytest.Metafunc):
 def test_model_prediction(prompt_inputs: tuple[str, str, list]):
     """Evaluates a model on a full task - is parametrized using pytest_generate_test"""
     metric, task_name, examples = prompt_inputs
-    print(metric, task_name)
+    metric_name = metric
+    metric = Metrics[metric].value
+    print(metric_name, task_name)
     for example in examples:
         formatted_doc = {
             k: v
@@ -83,7 +85,7 @@ def test_model_prediction(prompt_inputs: tuple[str, str, list]):
         print(formatted_doc)
         formatted_doc["query"] = formatted_doc.pop("full_prompt")
         formatted_doc = Doc(**formatted_doc)
-        error_msg = f"Metric {metric} failed on input {formatted_doc} from task {task_name}.\n"
+        error_msg = f"Metric {metric_name} failed on input {formatted_doc} from task {task_name}.\n"

         results = [ModelReturn(result=i, input_tokens=[], generated_tokens=[]) for i in example["predictions"]]
         # todo: update to create list of ModelResults in results
@@ -122,23 +124,23 @@ def test_model_prediction(prompt_inputs: tuple[str, str, list]):


 def apply_metric(metric, results, formatted_doc: Doc):
-    if Metrics[metric].value.category == MetricCategory.TARGET_PERPLEXITY:
+    if metric.category == MetricCategory.TARGET_PERPLEXITY:
         _, cur_outputs = apply_target_perplexity_metric(results=results, formatted_doc=formatted_doc, metrics=[metric])
         return cur_outputs
-    if Metrics[metric].value.category == MetricCategory.PERPLEXITY:
+    if metric.category == MetricCategory.PERPLEXITY:
         _, cur_outputs = apply_perplexity_metric(results=results, formatted_doc=formatted_doc, metrics=[metric])
         return cur_outputs
-    if Metrics[metric].value.category in [
+    if metric.category in [
         MetricCategory.GENERATIVE,
         MetricCategory.GENERATIVE_LOGPROB,
         MetricCategory.GENERATIVE_SAMPLING,
     ]:
         _, cur_outputs = apply_generative_metric(results=results, formatted_doc=formatted_doc, metrics=[metric])
         return cur_outputs
-    if Metrics[metric].value.category == MetricCategory.MULTICHOICE:
+    if metric.category == MetricCategory.MULTICHOICE:
         _, cur_outputs = apply_multichoice_metric(results=results, formatted_doc=formatted_doc, metrics=[metric])
         return cur_outputs
-    if Metrics[metric].value.category == MetricCategory.MULTICHOICE_ONE_TOKEN:
+    if metric.category == MetricCategory.MULTICHOICE_ONE_TOKEN:
         _, cur_outputs = apply_multichoice_metric_one_token(
             results=results, formatted_doc=formatted_doc, metrics=[metric]
         )
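
For context on what the patch changes: the test now resolves the Metrics enum entry once (metric = Metrics[metric].value) and keeps the original string around as metric_name for error messages, so apply_metric can dispatch directly on metric.category instead of repeating Metrics[metric].value.category in every branch. A minimal, self-contained sketch of that resolve-once / dispatch-on-category pattern follows; the enum, categories, metric names, and return values below are toy stand-ins for illustration, not lighteval's actual classes.

from dataclasses import dataclass
from enum import Enum


class MetricCategory(Enum):
    # Toy stand-in for a category enum, for illustration only.
    TARGET_PERPLEXITY = "target_perplexity"
    MULTICHOICE = "multichoice"


@dataclass(frozen=True)
class Metric:
    # Toy stand-in: only the field the dispatch needs.
    category: MetricCategory


class Metrics(Enum):
    # Toy registry keyed by metric name, mirroring the Metrics[name].value lookup.
    loglikelihood_acc = Metric(category=MetricCategory.MULTICHOICE)
    target_perplexity = Metric(category=MetricCategory.TARGET_PERPLEXITY)


def apply_metric(metric: Metric) -> str:
    # Dispatch on the already-resolved object's category, as the patched helper does.
    if metric.category == MetricCategory.TARGET_PERPLEXITY:
        return "perplexity path"
    if metric.category == MetricCategory.MULTICHOICE:
        return "multichoice path"
    raise ValueError(f"Unhandled category: {metric.category}")


metric_name = "loglikelihood_acc"    # the string name the test is parametrized with
metric = Metrics[metric_name].value  # resolve the enum entry once, up front
print(metric_name, apply_metric(metric))

Resolving the enum once keeps the string name available for readable failure messages while letting the dispatch helper work purely on the resolved metric object.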