diff --git a/src/lighteval/metrics/metrics_sample.py b/src/lighteval/metrics/metrics_sample.py index b7876dbc..04940d93 100644 --- a/src/lighteval/metrics/metrics_sample.py +++ b/src/lighteval/metrics/metrics_sample.py @@ -424,9 +424,7 @@ def __init__( normalize_pred (callable, optional): Function to use to normalize the predicted strings. Defaults to None if no normalization is applied. """ - self.bert_scorer = BERTScorer( - model_type="microsoft/deberta-large-mnli", lang="en", rescale_with_baseline=True, num_layers=9 - ) + self.bert_scorer = None self.normalize_gold = normalize_gold self.normalize_pred = normalize_pred @@ -441,6 +439,12 @@ def compute(self, golds: list[str], predictions: list[str]) -> dict: Returns: dict: Scores over the current sample's items. """ + if self.bert_scorer is None: + hlog_warn("The first metric computation step might be a bit long as we need to download the model.") + # We only initialize on first compute + self.bert_scorer = BERTScorer( + model_type="microsoft/deberta-large-mnli", lang="en", rescale_with_baseline=True, num_layers=9 + ) golds = as_list(golds) predictions = as_list(predictions) # Normalize