From a3d1eea1983b306a93e913471f3312d477c54a10 Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Wed, 3 Jul 2024 19:04:10 +0330 Subject: [PATCH] Download BERT scorer lazily (#190) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Clémentine Fourrier <22726840+clefourrier@users.noreply.github.com> --- src/lighteval/metrics/imports/bert_scorer.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/lighteval/metrics/imports/bert_scorer.py b/src/lighteval/metrics/imports/bert_scorer.py index dd8c0ee84..442ee9c75 100644 --- a/src/lighteval/metrics/imports/bert_scorer.py +++ b/src/lighteval/metrics/imports/bert_scorer.py @@ -375,11 +375,9 @@ def __init__( self._model_type = model_type self._num_layers = num_layers - # Building model and tokenizer - self._tokenizer = AutoTokenizer.from_pretrained(model_type) - self._model = AutoModel.from_pretrained(model_type) - self._model.eval() - self._model.to(self.device) + # Model and tokenizer are lazily loaded in `score()`. + self._tokenizer = None + self._model = None self._idf_dict = None @@ -443,6 +441,13 @@ def score(self, cands, refs, verbose=False, batch_size=64, return_hash=False): the *best* score among all references. """ + if self._model is None: + hlog(f"Loading BERTScorer model `{self._model_type}`") + self._tokenizer = AutoTokenizer.from_pretrained(self._model_type) + self._model = AutoModel.from_pretrained(self._model_type) + self._model.eval() + self._model.to(self.device) + ref_group_boundaries = None if not isinstance(refs[0], str): ref_group_boundaries = []