Skip to content

Commit

Permalink
Download BERT scorer lazily (#190)
Browse files Browse the repository at this point in the history
Co-authored-by: Clémentine Fourrier <[email protected]>
  • Loading branch information
sadra-barikbin and clefourrier authored Jul 3, 2024
1 parent a98210f commit a3d1eea
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions src/lighteval/metrics/imports/bert_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,11 +375,9 @@ def __init__(
self._model_type = model_type
self._num_layers = num_layers

# Building model and tokenizer
self._tokenizer = AutoTokenizer.from_pretrained(model_type)
self._model = AutoModel.from_pretrained(model_type)
self._model.eval()
self._model.to(self.device)
# Model and tokenizer are lazily loaded in `score()`.
self._tokenizer = None
self._model = None

self._idf_dict = None

Expand Down Expand Up @@ -443,6 +441,13 @@ def score(self, cands, refs, verbose=False, batch_size=64, return_hash=False):
the *best* score among all references.
"""

if self._model is None:
hlog(f"Loading BERTScorer model `{self._model_type}`")
self._tokenizer = AutoTokenizer.from_pretrained(self._model_type)
self._model = AutoModel.from_pretrained(self._model_type)
self._model.eval()
self._model.to(self.device)

ref_group_boundaries = None
if not isinstance(refs[0], str):
ref_group_boundaries = []
Expand Down

0 comments on commit a3d1eea

Please sign in to comment.