From ce2b3bf3b50015e6c8d2566eaeb3458e09de8224 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9mentine?=
Date: Thu, 18 Jul 2024 14:40:04 +0200
Subject: [PATCH 1/2] move init of bertscorer to avoid downloading model by default

---
 src/lighteval/metrics/metrics_sample.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/src/lighteval/metrics/metrics_sample.py b/src/lighteval/metrics/metrics_sample.py
index b7876dbc..04940d93 100644
--- a/src/lighteval/metrics/metrics_sample.py
+++ b/src/lighteval/metrics/metrics_sample.py
@@ -424,9 +424,7 @@ def __init__(
             normalize_pred (callable, optional): Function to use to normalize the predicted strings.
                 Defaults to None if no normalization is applied.
         """
-        self.bert_scorer = BERTScorer(
-            model_type="microsoft/deberta-large-mnli", lang="en", rescale_with_baseline=True, num_layers=9
-        )
+        self.bert_scorer = None
 
         self.normalize_gold = normalize_gold
         self.normalize_pred = normalize_pred
@@ -441,6 +439,12 @@ def compute(self, golds: list[str], predictions: list[str]) -> dict:
         Returns:
             dict: Scores over the current sample's items.
         """
+        if self.bert_scorer is None:
+            hlog_warn("The first metric computation step might be a bit long as we need to download the model.")
+            # We only initialize on first compute
+            self.bert_scorer = BERTScorer(
+                model_type="microsoft/deberta-large-mnli", lang="en", rescale_with_baseline=True, num_layers=9
+            )
         golds = as_list(golds)
         predictions = as_list(predictions)
         # Normalize

From a4731cdf9d95727cefefca508bc72f33f7e9c34d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9mentine=20Fourrier?= <22726840+clefourrier@users.noreply.github.com>
Date: Tue, 23 Jul 2024 14:34:28 +0200
Subject: [PATCH 2/2] Update src/lighteval/metrics/metrics_sample.py

Co-authored-by: Nathan Habib <30601243+NathanHB@users.noreply.github.com>
---
 src/lighteval/metrics/metrics_sample.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/lighteval/metrics/metrics_sample.py b/src/lighteval/metrics/metrics_sample.py
index 04940d93..a240166b 100644
--- a/src/lighteval/metrics/metrics_sample.py
+++ b/src/lighteval/metrics/metrics_sample.py
@@ -440,7 +440,7 @@ def compute(self, golds: list[str], predictions: list[str]) -> dict:
             dict: Scores over the current sample's items.
         """
         if self.bert_scorer is None:
-            hlog_warn("The first metric computation step might be a bit long as we need to download the model.")
+            hlog_warn("The first metric computation step might be a bit longer as we need to download the model.")
             # We only initialize on first compute
             self.bert_scorer = BERTScorer(
                 model_type="microsoft/deberta-large-mnli", lang="en", rescale_with_baseline=True, num_layers=9