Skip to content

Commit

Permalink
Removes default bert scorer init (#234)
Browse files Browse the repository at this point in the history
* move init of bertscorer to avoid downloading model by default

* Update src/lighteval/metrics/metrics_sample.py

Co-authored-by: Nathan Habib <[email protected]>

---------

Co-authored-by: Nathan Habib <[email protected]>
  • Loading branch information
clefourrier and NathanHB authored Jul 24, 2024
1 parent 86fbe64 commit 2b4b637
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions src/lighteval/metrics/metrics_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,9 +424,7 @@ def __init__(
normalize_pred (callable, optional): Function to use to normalize the predicted strings.
Defaults to None if no normalization is applied.
"""
self.bert_scorer = BERTScorer(
model_type="microsoft/deberta-large-mnli", lang="en", rescale_with_baseline=True, num_layers=9
)
self.bert_scorer = None

self.normalize_gold = normalize_gold
self.normalize_pred = normalize_pred
Expand All @@ -441,6 +439,12 @@ def compute(self, golds: list[str], predictions: list[str]) -> dict:
Returns:
dict: Scores over the current sample's items.
"""
if self.bert_scorer is None:
hlog_warn("The first metric computation step might be a bit longer as we need to download the model.")
# We only initialize on first compute
self.bert_scorer = BERTScorer(
model_type="microsoft/deberta-large-mnli", lang="en", rescale_with_baseline=True, num_layers=9
)
golds = as_list(golds)
predictions = as_list(predictions)
# Normalize
Expand Down

0 comments on commit 2b4b637

Please sign in to comment.